1# mypy: allow-untyped-defs 2import collections 3import functools 4import inspect 5import sys 6import textwrap 7import types 8import warnings 9from typing import Dict, List, Set, Type 10 11import torch 12import torch._jit_internal as _jit_internal 13from torch._sources import fake_range 14from torch.jit._builtins import _find_builtin 15from torch.jit._check import AttributeTypeIsSupportedChecker 16from torch.jit._state import _add_script_class, _get_script_class, _python_cu 17from torch.jit.frontend import ( 18 get_class_properties, 19 get_default_args, 20 get_jit_class_def, 21 get_jit_def, 22) 23from torch.nn import Module 24 25 26ScriptMethodStub = collections.namedtuple( 27 "ScriptMethodStub", ("resolution_callback", "def_", "original_method") 28) 29PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_")) 30 31 32# TODO: there should be a more principled way of doing this. 33ignored_attributes = [ 34 "_version", 35 "_parameters", 36 "_buffers", 37 "_non_persistent_buffers_set", 38 "_backward_hooks", 39 "_backward_pre_hooks", 40 "_forward_hooks", 41 "_forward_hooks_with_kwargs", 42 "_forward_pre_hooks", 43 "_forward_pre_hooks_with_kwargs", 44 "_forward_hooks_always_called", 45 "_state_dict_hooks", 46 "_state_dict_pre_hooks", 47 "_load_state_dict_pre_hooks", 48 "_load_state_dict_post_hooks", 49 "_modules", 50 "_initializing", 51 "dump_patches", 52] 53 54 55def _compile_and_register_class(obj, rcb, qualified_name): 56 script_class = _get_script_class(obj) 57 58 if not script_class: 59 ast = get_jit_class_def(obj, obj.__name__) 60 defaults = torch.jit.frontend.get_default_args_for_class(obj) 61 script_class = torch._C._jit_script_class_compile( 62 qualified_name, ast, defaults, rcb 63 ) 64 _add_script_class(obj, script_class) 65 66 return script_class 67 68 69def make_stub(func, name): 70 rcb = _jit_internal.createResolutionCallbackFromClosure(func) 71 ast = get_jit_def(func, name, self_name="RecursiveScriptModule") 72 return ScriptMethodStub(rcb, ast, func) 73 74 75def make_stub_from_method(nn_module, method_name): 76 func = getattr(nn_module, method_name) 77 if isinstance(func, ScriptMethodStub): 78 return func 79 # Make sure the name present in the resulting AST will match the name 80 # requested here. The only time they don't match is if you do something 81 # like: 82 # def _forward(self): 83 # pass 84 # forward = _forward 85 # In this case, the actual function object will have the name `_forward`, 86 # even though we requested a stub for `forward`. 87 return make_stub(func, method_name) 88 89 90def make_stubs_from_exported_methods(mod): 91 stubs = [] 92 for name in dir(mod): 93 item = getattr(mod, name, None) 94 if ( 95 _jit_internal.get_torchscript_modifier(item) 96 is _jit_internal.FunctionModifiers.EXPORT 97 ): 98 stubs.append(make_stub_from_method(mod, name)) 99 100 return stubs 101 102 103def jit_ignored_properties(module): 104 user_annotated_ignored_attributes = getattr( 105 module, "__jit_ignored_attributes__", [] 106 ) 107 108 def get_properties_names(module): 109 return {k for k, v in vars(module).items() if isinstance(v, property)} 110 111 properties = get_properties_names(type(module)) 112 user_annoted_ignored_properties = set() 113 114 for ignored_attr in user_annotated_ignored_attributes: 115 if ignored_attr in properties: 116 user_annoted_ignored_properties.add(ignored_attr) 117 return user_annoted_ignored_properties 118 119 120# base types that can be constants 121# in addition, tuples and lists of these base types are also considered constants 122# If you edit this list, then you also need to edit the handlers in 123# ConstantValue in jit/script/init.cpp 124_constant_types = ( 125 bool, 126 float, 127 int, 128 str, 129 type(None), 130 torch.device, 131 torch.layout, 132 torch.dtype, 133) 134 135 136def _get_valid_constant(attr, v, owner_type): 137 if isinstance(v, _constant_types): 138 return v 139 elif isinstance(v, (tuple, list)): 140 return tuple(_get_valid_constant(attr, x, owner_type) for x in v) 141 constants = ", ".join(torch.typename(typ) for typ in _constant_types) 142 raise TypeError( 143 textwrap.dedent( 144 f""" 145 '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant. 146 Valid constants are: 147 1. a nn.ModuleList 148 2. a value of type {{{constants}}} 149 3. a list or tuple of (2) 150 """ 151 ) 152 ) 153 154 155class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): 156 def __init__(self, source, filename, file_lineno, leading_whitespace_len): 157 super().__init__(source, filename, file_lineno, leading_whitespace_len) 158 159 160def get_annotations(obj): 161 if sys.version_info < (3, 10): 162 return getattr(obj, "__annotations__", {}) 163 # In Python-3.10+ it is recommended to use inspect.get_annotations 164 # See https://docs.python.org/3.10/howto/annotations.html 165 # But also, in 3.10 annotations from base class are not inherited 166 # by unannotated derived one, so they must be manually extracted 167 annotations = inspect.get_annotations(obj) 168 if annotations: 169 return annotations 170 171 def get_cls_annotations(cls): 172 cls_annotations = inspect.get_annotations(cls) 173 if cls_annotations: 174 return cls_annotations 175 for base in cls.__bases__: 176 cls_annotations = get_cls_annotations(base) 177 if cls_annotations: 178 return cls_annotations 179 return {} 180 181 cls = obj if isinstance(obj, type) else type(obj) 182 return get_cls_annotations(cls) 183 184 185def infer_concrete_type_builder(nn_module, share_types=True): 186 """ 187 Build a ConcreteModuleTypeBuilder from an nn.Module. 188 189 This ConcreteModuleType doesn't have a JIT type associated with it yet, it 190 must be filled in by the caller. 191 """ 192 concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module)) 193 if isinstance(nn_module, (torch.nn.ModuleDict)): 194 concrete_type_builder.set_module_dict() 195 if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)): 196 concrete_type_builder.set_module_list() 197 if isinstance(nn_module, (torch.nn.ParameterList)): 198 concrete_type_builder.set_parameter_list() 199 if isinstance(nn_module, (torch.nn.ParameterDict)): 200 concrete_type_builder.set_parameter_dict() 201 202 class_annotations = get_annotations(nn_module) 203 if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)): 204 class_annotations = {} 205 206 # Get user-annotated ignored attributes. 207 user_annotated_ignored_attributes = getattr( 208 nn_module, "__jit_ignored_attributes__", [] 209 ) 210 concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) 211 ignored_properties = jit_ignored_properties(nn_module) 212 213 # try to infer the type from type annotation or from the object itself 214 def infer_type(name, item): 215 # The forward function from Module is special; never use this annotations; we 216 # need to infer type directly using JIT. I originally wanted to write 217 # this test as isinstance(class_annotations[name], Callable) but 218 # isinstance on typing things doesn't seem to work: isinstance(list, Callable) 219 # is also true! 220 inferred = False 221 try: 222 if ( 223 name in class_annotations 224 and class_annotations[name] 225 != torch.nn.Module.__annotations__["forward"] 226 ): 227 ann_to_type = torch.jit.annotations.ann_to_type( 228 class_annotations[name], fake_range() 229 ) 230 attr_type = torch._C.InferredType(ann_to_type) 231 elif isinstance(item, torch.jit.Attribute): 232 ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range()) 233 attr_type = torch._C.InferredType(ann_to_type) 234 else: 235 attr_type = torch._C._jit_try_infer_type(item) 236 inferred = True 237 except RuntimeError as re: 238 raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re 239 240 return attr_type, inferred 241 242 added_names = set() 243 244 for name, item in nn_module._parameters.items(): 245 if name in user_annotated_ignored_attributes: 246 continue 247 248 assert item is None or isinstance(item, torch.Tensor) 249 attr_type, _ = infer_type(name, item) 250 # We currently have the invariant in various places in our code 251 # that parameters must be Tensors. However, the nn.Module API also 252 # allows NoneType parameters. These parameters are not returned as 253 # part of `parameters()` and its variants, but are available 254 # through direct attribute access. 255 concrete_type_builder.add_attribute(name, attr_type.type(), True, False) 256 added_names.add(name) 257 258 for name, item in nn_module._buffers.items(): 259 if name in user_annotated_ignored_attributes: 260 continue 261 262 assert item is None or isinstance(item, torch.Tensor) 263 attr_type, _ = infer_type(name, item) 264 concrete_type_builder.add_attribute(name, attr_type.type(), False, True) 265 added_names.add(name) 266 267 for name, item in nn_module._modules.items(): 268 if name in user_annotated_ignored_attributes: 269 continue 270 271 attr_type, _ = infer_type(name, item) 272 if item is None: 273 # Modules can be None. We don't have direct support for optional 274 # Modules, so the register it as an NoneType attribute instead. 275 concrete_type_builder.add_attribute(name, attr_type.type(), False, False) 276 continue 277 if attr_type.success(): 278 assert attr_type.type().is_interface_type() 279 # if the type can be inferred, it should be a module interface type 280 sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type( 281 attr_type.type() 282 ) 283 else: 284 # otherwise we get the concrete module type for item and add it to concrete_type 285 sub_concrete_type = get_module_concrete_type(item, share_types) 286 concrete_type_builder.add_module(name, sub_concrete_type) 287 288 added_names.add(name) 289 290 # populate constants_set 291 constants_set = set(getattr(nn_module, "__constants__", ())) 292 293 # Constants annotated via `Final[T]` rather than being added to `__constants__` 294 for name, ann in class_annotations.items(): 295 if torch._jit_internal.is_final(ann): 296 constants_set.add(name) 297 298 for name in constants_set: 299 if name in added_names: 300 # TODO: We should really error in this case, but its bc-breaking so 301 # we need to warn for at least one release 302 if name in nn_module._modules: 303 hint = "submodule" 304 elif name in nn_module._buffers: 305 hint = "buffer" 306 elif name in nn_module._parameters: 307 hint = "parameter" 308 else: 309 raise AssertionError( 310 "added_names must be submodule, parameter, or buffer" 311 ) 312 313 warnings.warn( 314 f"'{name}' was found in ScriptModule constants, " 315 f" but it is a non-constant {hint}. Consider removing it." 316 ) 317 continue 318 if not hasattr(nn_module, name): 319 # TODO: We should really error in this case, but its bc-breaking so 320 # we need to warn for at least one release 321 warnings.warn( 322 f"'{name}' was found in ScriptModule constants, " 323 "but was not actually set in __init__. " 324 "Consider removing it." 325 ) 326 continue 327 value = getattr(nn_module, name) 328 concrete_type_builder.add_constant( 329 name, _get_valid_constant(name, value, type(nn_module).__name__) 330 ) 331 added_names.add(name) 332 333 # populate overloads 334 overloads = getattr(nn_module, "__overloads__", {}) 335 # update with any annotated overloads 336 overloads.update( 337 get_overload_name_mapping( 338 get_overload_annotations(nn_module, ignored_properties) 339 ) 340 ) 341 for name, overloaded_names in overloads.items(): 342 concrete_type_builder.add_overload(name, overloaded_names) 343 344 for name, value in nn_module.__dict__.items(): 345 if name in ignored_attributes or name.startswith("__"): 346 # Python objects have lots of random attributes attached to them; 347 # PyTorch adds a few more. Prevent these from getting compiled. 348 continue 349 350 if name in user_annotated_ignored_attributes: 351 continue 352 353 if name in added_names: 354 # Don't re-add anything we already added 355 continue 356 357 isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket) 358 if isoverloadpacket: 359 value = value.op 360 # Handle Python function attributes 361 if inspect.isfunction(value): 362 try: 363 scripted_fn = torch.jit.script(value) 364 concrete_type_builder.add_function_attribute( 365 name, torch._C._jit_try_infer_type(scripted_fn).type(), value 366 ) 367 except Exception as e: 368 # If we fail to script the function, it isn't a hard error. 369 # Instead, we will add it to the list of attributes we failed 370 # to convert, with the compilation error. 371 hint = ( 372 "(This function exists as an attribute on the Python module, " 373 "but we failed to compile it to a TorchScript function. " 374 f"\nThe error stack is reproduced here:\n{e}" 375 ) 376 concrete_type_builder.add_failed_attribute(name, hint) 377 378 continue 379 380 # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or 381 # a call to an aten function like torch.add) 382 builtin_symbol_name = _find_builtin(value) 383 if builtin_symbol_name: 384 concrete_type_builder.add_builtin_function(name, builtin_symbol_name) 385 continue 386 387 # Handle Script function attributes 388 if isinstance(value, torch.jit.ScriptFunction): 389 concrete_type_builder.add_function_attribute( 390 name, torch._C._jit_try_infer_type(value).type(), value 391 ) 392 continue 393 394 # If we got here, this is a regular "data" attribute, add it to the concrete type 395 attr_type, inferred = infer_type(name, value) 396 if attr_type.success(): 397 concrete_type_builder.add_attribute(name, attr_type.type(), False, False) 398 else: 399 # TODO: could add more detail here. For example, what the user should do 400 # when the pytype is `list` or `NoneType` 401 inferred_msg = ( 402 "Its type was inferred; try adding a type annotation for the attribute." 403 if inferred 404 else "" 405 ) 406 additional_info = f"{attr_type.reason()}. {inferred_msg}" 407 hint = ( 408 "(This attribute exists on the Python module, " 409 f"but we failed to convert Python type: '{torch.typename(type(value))}' " 410 f"to a TorchScript type. {additional_info})" 411 ) 412 concrete_type_builder.add_failed_attribute(name, hint) 413 414 # add hooks to concrete type 415 for hook in nn_module._forward_hooks.values(): 416 concrete_type_builder.add_forward_hook(hook) 417 for pre_hook in nn_module._forward_pre_hooks.values(): 418 concrete_type_builder.add_forward_pre_hook(pre_hook) 419 420 return concrete_type_builder 421 422 423class ConcreteTypeStore: 424 type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]] 425 methods_compiled: Set[torch._C.ConcreteModuleType] 426 427 def __init__(self) -> None: 428 # Python module type => List[ConcreteModuleType)] 429 self.type_store = {} 430 # ConcreteTypes that have had their methods already compiled 431 self.methods_compiled = set() 432 433 def get_or_create_concrete_type(self, nn_module): 434 """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible.""" 435 concrete_type_builder = infer_concrete_type_builder(nn_module) 436 437 nn_module_type = type(nn_module) 438 if nn_module_type not in self.type_store: 439 self.type_store[nn_module_type] = [] 440 441 # Search the type store for an already-available JIT type 442 known_types = self.type_store[nn_module_type] 443 for known_type in known_types: 444 if known_type.equals(concrete_type_builder): 445 return known_type 446 447 # We didn't find anything; generate a new JIT type from this concrete type 448 concrete_type = concrete_type_builder.build() 449 self.type_store[nn_module_type].append(concrete_type) 450 return concrete_type 451 452 453concrete_type_store = ConcreteTypeStore() 454 455 456def create_methods_and_properties_from_stubs( 457 concrete_type, method_stubs, property_stubs 458): 459 method_defs = [m.def_ for m in method_stubs] 460 method_rcbs = [m.resolution_callback for m in method_stubs] 461 method_defaults = [get_default_args(m.original_method) for m in method_stubs] 462 463 property_defs = [p.def_ for p in property_stubs] 464 property_rcbs = [p.resolution_callback for p in property_stubs] 465 466 concrete_type._create_methods_and_properties( 467 property_defs, property_rcbs, method_defs, method_rcbs, method_defaults 468 ) 469 470 471def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs): 472 hook_defs = [h.def_ for h in hook_stubs] 473 hook_rcbs = [h.resolution_callback for h in hook_stubs] 474 475 pre_hook_defs = [h.def_ for h in pre_hook_stubs] 476 pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs] 477 478 concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs) 479 480 481def get_module_concrete_type(nn_module, share_types=True): 482 """ 483 Get a concrete type for nn_modules. 484 485 If share_types is True, the concrete type is fetched from concrete_type_store. 486 If it is False, a new concrete type is created without first searching concrete_type_store. 487 488 Args: 489 nn_module: The original Python nn.Module that we are creating a ScriptModule for. 490 share_types = Whether to share underlying JIT types between modules (if possible). 491 492 Returns: 493 A concrete type for nn_module. 494 """ 495 assert isinstance(nn_module, Module) 496 if isinstance(nn_module, torch.jit.ScriptModule) and hasattr( 497 nn_module, "_concrete_type" 498 ): 499 return nn_module._concrete_type 500 501 if share_types: 502 # Look into the store of cached JIT types 503 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) 504 else: 505 # Get a concrete type directly, without trying to re-use an existing JIT 506 # type from the type store. 507 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types) 508 concrete_type_builder.set_poisoned() 509 concrete_type = concrete_type_builder.build() 510 511 return concrete_type 512 513 514def create_script_class(obj): 515 """ 516 Create and return a RecursiveScriptClass instance from a Python object. 517 518 Arguments: 519 obj: A Python object. 520 """ 521 qualified_class_name = _jit_internal._qualified_name(type(obj)) 522 rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj)) 523 # Script the type of obj if it hasn't already been scripted. 524 _compile_and_register_class(type(obj), rcb, qualified_class_name) 525 class_ty = _python_cu.get_class(qualified_class_name) 526 # Create an empty torch._C.ScriptObject with the scripted type. 527 cpp_object = torch._C._create_object_with_type(class_ty) 528 # Copy all of the attributes over to the torch._C.ScriptObject. 529 for name, value in obj.__dict__.items(): 530 cpp_object.setattr(name, value) 531 532 # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance. 533 return wrap_cpp_class(cpp_object) 534 535 536def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False): 537 """ 538 Create a new ScriptModule from an nn.Module. 539 540 Args: 541 nn_module: The original Python nn.Module that we are creating a ScriptModule for. 542 stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. 543 share_types: Whether to share underlying JIT types between modules (if possible). 544 NOTE: Only set to False this when we cannot guarantee type sharing will work 545 correctly. This only happens today for traced modules, where the same 546 module can produce different traced methods depending on the inputs. 547 is_tracing: Whether this function is called during tracing or scripting. If tracing, 548 we don't need to do AttributeTypeIsSupportedChecker because all the unsupported 549 attributes will be baked as constant in the tracing graph. In addition, 550 this check significantly slows down the traced modules when the module size is big. 551 """ 552 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule) 553 check_module_initialized(nn_module) 554 concrete_type = get_module_concrete_type(nn_module, share_types) 555 if not is_tracing: 556 AttributeTypeIsSupportedChecker().check(nn_module) 557 return create_script_module_impl(nn_module, concrete_type, stubs_fn) 558 559 560def create_script_module_impl(nn_module, concrete_type, stubs_fn): 561 """ 562 Convert an nn.Module to a RecursiveScriptModule. 563 564 Args: 565 nn_module: The original Python nn.Module that we are creating a ScriptModule for. 566 concrete_type: The fully initialized ConcreteType of the module. 567 stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. 568 """ 569 cpp_module = torch._C._create_module_with_type(concrete_type.jit_type) 570 method_stubs = stubs_fn(nn_module) 571 property_stubs = get_property_stubs(nn_module) 572 hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module) 573 574 user_annotated_ignored_attributes = getattr( 575 nn_module, "__jit_ignored_attributes__", [] 576 ) 577 ignored_properties = jit_ignored_properties(nn_module) 578 579 def init_fn(script_module): 580 # Initialize the ScriptModule: 581 # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. 582 for name in concrete_type.get_attributes().keys(): 583 orig_value = getattr(nn_module, name) 584 orig_value = ( 585 orig_value.value 586 if isinstance(orig_value, torch.jit.Attribute) 587 else orig_value 588 ) 589 cpp_module.setattr(name, orig_value) 590 591 # 2. Copy the submodules from the original `nn_module` to the new ScriptModule, 592 # recursively scripting them. 593 for name, sub_concrete_type in concrete_type.get_modules(): 594 orig_value = getattr(nn_module, name) 595 assert isinstance( 596 orig_value, Module 597 ), f"Expected Module but got {type(orig_value)}" 598 module_type = sub_concrete_type.jit_type 599 if isinstance(module_type, torch._C.InterfaceType): 600 # use the interface inference rule to compile the module 601 scripted = interface_script(module_type, orig_value) 602 elif isinstance(orig_value, torch.jit.ScriptModule): 603 scripted = orig_value 604 else: 605 # always reuse the provided stubs_fn to infer the methods to compile 606 scripted = create_script_module_impl( 607 orig_value, sub_concrete_type, stubs_fn 608 ) 609 610 cpp_module.setattr(name, scripted) 611 script_module._modules[name] = scripted 612 613 # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. 614 # This ensures we can access these Python methods on the ScriptModule. 615 for name in dir(nn_module): 616 if name in ignored_properties: 617 continue 618 item = getattr(nn_module, name, None) 619 if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): 620 unbound_function = getattr(nn_module, name).__func__ 621 bound_method = unbound_function.__get__(script_module) 622 setattr(script_module, name, bound_method) 623 elif concrete_type.is_ignored_attribute(name): 624 setattr(script_module, name, item) 625 626 # For convenience, attach the concrete type to the new ScriptModule 627 script_module._concrete_type = concrete_type 628 629 # Actually create the ScriptModule, initializing it with the function we just defined 630 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) 631 632 # Compile methods if necessary 633 if concrete_type not in concrete_type_store.methods_compiled: 634 create_methods_and_properties_from_stubs( 635 concrete_type, method_stubs, property_stubs 636 ) 637 # Create hooks after methods to ensure no name collisions between hooks and methods. 638 # If done before, hooks can overshadow methods that aren't exported. 639 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) 640 torch._C._run_emit_module_hook(cpp_module) 641 concrete_type_store.methods_compiled.add(concrete_type) 642 643 # Copy the forward hooks and pre-hooks to the new ScriptModule 644 # to allow the hooks to be run from eager as ScriptFunctions 645 for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()): 646 script_module._forward_pre_hooks[idx] = fn 647 for idx, fn in enumerate(script_module._c._get_forward_hooks()): 648 script_module._forward_hooks[idx] = fn 649 650 # Special handling so methods like __len__ work in script methods on classes derived from containers 651 if ( 652 isinstance( 653 nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) 654 ) 655 and "__len__" not in cpp_module._method_names() 656 ): 657 script_module.define(f"def __len__(self):\n return {len(nn_module)}\n") 658 if ( 659 isinstance(nn_module, torch.nn.ModuleDict) 660 and "__contains__" not in cpp_module._method_names() 661 ): 662 if len(nn_module.keys()): 663 keys = repr(list(nn_module.keys())) 664 script_module.define( 665 f"def __contains__(self, key: str):\n return key in {keys}\n" 666 ) 667 else: 668 script_module.define("def __contains__(self, key: str):\n return False\n") 669 670 # Make the compiled methods available to the Python ScriptModule class. 671 for method_stub in method_stubs: 672 if method_stub.original_method is None: 673 # define()'d methods don't have an Python original_method, so we 674 # don't need to do any Python re-wrapping stuff 675 continue 676 677 name = method_stub.original_method.__name__ 678 if name != method_stub.def_.name().name: 679 # TODO: Why skip this? Because @torch.jit._overload_method will 680 # mangle the name of the function. 681 continue 682 script_method = cpp_module._get_method(name) 683 684 # Wrap the original to propagate docstrings and such. 685 # TODO: we don't currently do this functions that are recursively 686 # compiled, we should. 687 wrapped_script_method = functools.wraps(method_stub.original_method)( 688 script_method 689 ) 690 691 # Add the methods to the script_module directly. This ensures they will 692 # be found first when `name` is looked up (as opposed to the stubs or 693 # nn.Module.forward) 694 script_module.__dict__[name] = wrapped_script_method 695 696 # Make module properties available on the Python ScriptModule class. 697 for property_stub in property_stubs: 698 property_name = property_stub.def_.name().name 699 fget = cpp_module._get_method(property_stub.def_.getter_name().name) 700 # Setter is optional, so it may not exist. 701 setter_name = property_stub.def_.setter_name() 702 fset = cpp_module._get_method(setter_name.name) if setter_name else None 703 script_module.__dict__[property_name] = property(property_name, fget, fset) # type: ignore[arg-type] 704 705 # copy over python methods to script module if they aren't defined on the script module 706 # this is currently an internal api used only on module containers 707 for name in dir(nn_module): 708 if name in ignored_properties: 709 continue 710 item = getattr(nn_module, name, None) 711 if ( 712 _jit_internal.get_torchscript_modifier(item) 713 is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER 714 ): 715 add_python_attr_to_scripted_model(script_module, nn_module, name) 716 717 return script_module 718 719 720# We define shims of certain attributes on the RecursiveScriptModule to support 721# magic methods. To check if a script model defines an attribute we need 722# to also check that the attribute is not the shim 723def script_model_defines_attr(script_model, attr): 724 script_attr = getattr(script_model, attr, None) 725 if script_attr is None: 726 return False 727 default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None) 728 if default_attr is None: 729 return False 730 return script_attr != default_attr 731 732 733def add_python_attr_to_scripted_model(script_model, orig, attr): 734 if hasattr(orig, attr) and script_model_defines_attr(script_model, attr): 735 setattr(script_model, attr, getattr(orig, attr)) 736 737 738def get_overload_annotations(mod, jit_ignored_properties): 739 # original function => [(mangled overload name, overload function)] 740 overloads = {} 741 742 for name in dir(type(mod)): 743 if name in jit_ignored_properties: 744 continue 745 item = getattr(mod, name, None) 746 if not callable(item): 747 continue 748 749 # builtin functions like repr() in python 2 do not have __module__ defined 750 if hasattr(item, "__module__") and item.__module__ is not None: 751 method_overloads = _jit_internal._get_overloaded_methods( 752 item, mod.__class__ 753 ) 754 if method_overloads is None: 755 continue 756 757 if item.__func__ in method_overloads: 758 raise RuntimeError( 759 _jit_internal.get_overload_no_implementation_error_message( 760 "method", item.__func__ 761 ) 762 ) 763 764 names = [name + "__" + str(i) for i in range(len(method_overloads))] 765 overloads[item] = list(zip(names, method_overloads)) 766 767 return overloads 768 769 770def get_overload_name_mapping(overload_info): 771 # Same format as __overloads__ 772 # original function => [overload names] 773 overload_name_mappings: Dict[str, List[str]] = {} 774 for orig_fn, overloads in overload_info.items(): 775 original_name = orig_fn.__name__ 776 if original_name not in overload_name_mappings: 777 overload_name_mappings[original_name] = [] 778 779 for overload_name, _ in overloads: 780 overload_name_mappings[original_name].append(overload_name) 781 return overload_name_mappings 782 783 784def _check_no_signature(func): 785 signature = torch.jit.annotations.get_signature( 786 func, None, fake_range(), inspect.ismethod(func) 787 ) 788 if signature is None: 789 qual_name = _jit_internal._qualified_name(func) 790 raise RuntimeError( 791 f"Must explicitly add type annotations to overloaded functions: {qual_name}" 792 ) 793 794 795def make_stubs_for_overloads(overload_info): 796 overload_stubs = [] 797 for orig_fn, overloads in overload_info.items(): 798 orig_ast = get_jit_def( 799 orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule" 800 ) 801 for overload_name, overload_fn in overloads: 802 _check_no_signature(overload_fn) 803 over_ast = get_jit_def( 804 overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule" 805 ) 806 new_ast = torch._C._replace_overloaded_method_decl( 807 over_ast.decl(), orig_ast, overload_name 808 ) 809 _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn) 810 overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn)) 811 return overload_stubs 812 813 814def check_module_initialized(mod): 815 assert isinstance(mod, torch.nn.Module) 816 if not hasattr(mod, "_parameters"): 817 raise RuntimeError( 818 f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?" 819 ) 820 821 # This is to avoid importing torch.distributed.nn 822 if not hasattr(mod, "remote_parameters"): 823 for name, param in mod._parameters.items(): 824 if param is not None and torch.nn.parameter.is_lazy(param): 825 raise RuntimeError( 826 f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?" 827 ) 828 for name, buf in mod._buffers.items(): 829 if buf is not None and torch.nn.parameter.is_lazy(buf): 830 raise RuntimeError( 831 f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?" 832 ) 833 834 835def infer_methods_to_compile(nn_module): 836 """Implement the default rules for which methods should act as starting points for compilation. 837 838 (TODO add a link when the rules are published). 839 """ 840 check_module_initialized(nn_module) 841 user_annotated_ignored_attributes = getattr( 842 nn_module, "__jit_ignored_attributes__", [] 843 ) 844 ignored_properties = jit_ignored_properties(nn_module) 845 846 methods: List[str] = [] 847 if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn( 848 nn_module.forward 849 ): 850 forward_func = getattr(nn_module.forward, "__func__", None) 851 module_forward = getattr(torch.nn.Module, "forward", None) 852 if forward_func != module_forward: 853 methods = ["forward"] 854 855 exported = [] 856 for name in dir(nn_module): 857 if name in ignored_properties: 858 continue 859 item = getattr(nn_module, name, None) 860 if ( 861 _jit_internal.get_torchscript_modifier(item) 862 is _jit_internal.FunctionModifiers.EXPORT 863 ): 864 exported.append(name) 865 866 methods = methods + exported 867 868 overload_name_mappings = dict(getattr(nn_module, "__overloads__", {})) 869 overload_info = get_overload_annotations(nn_module, ignored_properties) 870 overload_name_mappings.update(get_overload_name_mapping(overload_info)) 871 overload_stubs = make_stubs_for_overloads(overload_info) 872 873 nn_module.__overloads__ = overload_name_mappings 874 875 # we shouldn't directly compile overloaded methods, just its overloads 876 def ignore_overloaded(method_name): 877 return method_name not in overload_name_mappings 878 879 filtered_methods = filter(ignore_overloaded, methods) 880 881 # Unique the methods. We don't want to use a set to store the methods because it 882 # introduces non-determinism to compile order. 883 uniquer: Set[str] = set() 884 uniqued_methods = [] 885 for name in filtered_methods: 886 if name in uniquer: 887 continue 888 uniqued_methods.append(name) 889 uniquer.add(name) 890 891 stubs = [] 892 for method in uniqued_methods: 893 stubs.append(make_stub_from_method(nn_module, method)) 894 return overload_stubs + stubs 895 896 897def get_hook_stubs(nn_module): 898 """Return forward hook and pre_hook ScriptModuleStubs.""" 899 check_module_initialized(nn_module) 900 hook_map: Dict = {} 901 902 hook_stubs = [] 903 for hook in nn_module._forward_hooks.values(): 904 if hook.__name__ in hook_map: 905 if id(hook) != id(hook_map[hook.__name__]): 906 raise RuntimeError( 907 f"Hook '{hook.__name__}' on {type(nn_module).__name__} " 908 "has at least two different python definitions." 909 " Please use unique names for all hooks." 910 ) 911 else: 912 hook_map[hook.__name__] = hook 913 hook_stubs.append(make_stub(hook, hook.__name__)) 914 915 pre_hook_stubs = [] 916 for pre_hook in nn_module._forward_pre_hooks.values(): 917 if pre_hook.__name__ in hook_map: 918 if id(pre_hook) != id(hook_map[pre_hook.__name__]): 919 raise RuntimeError( 920 f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} " 921 "has at least two different python definitions." 922 " Please use unique names for all hooks." 923 ) 924 else: 925 hook_map[pre_hook.__name__] = pre_hook 926 pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__)) 927 928 return hook_stubs, pre_hook_stubs 929 930 931def get_property_stubs(nn_module): 932 """Create property stubs for the properties of the module by creating method stubs for the getter and setter.""" 933 module_ty = type(nn_module) 934 properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule") 935 rcbs = {} 936 937 for name in dir(module_ty): 938 item = getattr(module_ty, name, None) 939 if isinstance(item, property): 940 if not item.fget: 941 raise RuntimeError( 942 f"Property {name} of {nn_module.__name__} must have a getter" 943 ) 944 945 rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget) 946 947 stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts] 948 return stubs 949 950 951def interface_script(mod_interface, nn_module): 952 """ 953 Make a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile. 954 955 Args: 956 mod_interface: the interface type that the module have 957 nn_module: The original Python nn.Module that we are creating a ScriptModule for. 958 """ 959 if isinstance(nn_module, torch.jit.ScriptModule): 960 return nn_module 961 962 check_module_initialized(nn_module) 963 964 def infer_interface_methods_to_compile(nn_module): 965 """Rule to infer the methods from the interface type. 966 967 It is used to know which methods need to act as starting points for compilation. 968 """ 969 stubs = [] 970 for method in mod_interface.getMethodNames(): 971 stubs.append(make_stub_from_method(nn_module, method)) 972 return stubs 973 974 return create_script_module(nn_module, infer_interface_methods_to_compile) 975 976 977def try_compile_fn(fn, loc): 978 if _jit_internal.is_ignored_fn(fn): 979 # Don't do anything for @ignore'd functions 980 return None 981 982 if isinstance(fn, torch.nn.Module): 983 # Since modules are callable pybind recognizes them as functions, but 984 # don't do anything for them 985 return None 986 987 if not inspect.isfunction(fn) and not inspect.ismethod(fn): 988 raise RuntimeError( 989 f"`{fn}` is not a function. Recursive scripting only supports " 990 "Python functions or methods currently.\n" 991 f"Consider manually annotating `{fn}` with @torch.jit.script." 992 ) 993 994 # The object returned by __prepare_scriptable__ might have a different closure. 995 # Resolve it here to get the right resolution callback. 996 fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator] 997 998 # We don't have the actual scope where the function was defined, but we can 999 # extract the necessary info from the closed over variables on the function 1000 # object 1001 rcb = _jit_internal.createResolutionCallbackFromClosure(fn) 1002 return torch.jit.script(fn, _rcb=rcb) 1003 1004 1005def wrap_cpp_class(cpp_class): 1006 """Wrap this torch._C.Object in a Python RecursiveScriptClass.""" 1007 return torch.jit.RecursiveScriptClass(cpp_class) 1008 1009 1010def wrap_cpp_module(cpp_module): 1011 """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules.""" 1012 1013 def init_fn(script_module): 1014 for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): 1015 setattr(script_module, name, wrap_cpp_module(cpp_module)) 1016 script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type( 1017 script_module._c._type() 1018 ) 1019 1020 for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()): 1021 script_module._forward_pre_hooks[idx] = fn 1022 for idx, fn in enumerate(script_module._c._get_forward_hooks()): 1023 script_module._forward_hooks[idx] = fn 1024 1025 return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) 1026 1027 1028def compile_unbound_method(concrete_type, fn): 1029 if _jit_internal.is_ignored_fn(fn): 1030 return None 1031 stub = make_stub(fn, fn.__name__) 1032 with torch._jit_internal._disable_emit_hooks(): 1033 # We don't want to call the hooks here since the graph that is calling 1034 # this function is not yet complete 1035 create_methods_and_properties_from_stubs(concrete_type, (stub,), ()) 1036 return stub 1037 1038 1039def lazy_bind(concrete_type, unbound_method): 1040 """ 1041 Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method. 1042 1043 We do this so that any Python shenanigans that 1044 will poison type sharing are impossible at compile time. 1045 """ 1046 1047 def lazy_binding_method(cpp_module, *args): 1048 def init_fn(script_module): 1049 orig_class = concrete_type.py_class 1050 1051 # Copy @ignored/@unused methods from the original module to the new one. 1052 # This ensures they are available during execution. 1053 for name in dir(orig_class): 1054 item = getattr(orig_class, name, None) 1055 if _jit_internal.is_ignored_fn(item): 1056 setattr(script_module, name, item) 1057 1058 # Copy constants over so they are available during execution. 1059 for name, value in concrete_type.get_constants().items(): 1060 setattr(script_module, name, value) 1061 1062 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) 1063 method = types.MethodType(unbound_method, script_module) 1064 return method(*args) 1065 1066 # make the lazy binding method "look like" the original method 1067 lazy_binding_method.original_fn = unbound_method # type: ignore[attr-defined] 1068 lazy_binding_method.__name__ = unbound_method.__name__ 1069 torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method) 1070 1071 return lazy_binding_method 1072