"""TorchScript.

This module contains functionality to support the JIT's scripting frontend, notably:
    - torch.jit.script

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import collections
import copy
import enum
import functools
import inspect
import pickle
import warnings
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
import torch._jit_internal as _jit_internal
from torch._classes import classes
from torch._jit_internal import _get_model_id, _qualified_name
from torch._utils_internal import log_torchscript_usage
from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for
from torch.jit._monkeytype_config import (
    JitTypeTraceConfig,
    JitTypeTraceStore,
    monkeytype_trace,
)
from torch.jit._recursive import (
    _compile_and_register_class,
    infer_methods_to_compile,
    ScriptMethodStub,
    wrap_cpp_module,
)
from torch.jit._state import (
    _enabled,
    _set_jit_function_cache,
    _set_jit_overload_cache,
    _try_get_jit_cached_function,
    _try_get_jit_cached_overloads,
)
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def
from torch.nn import Module
from torch.overrides import (
    has_torch_function,
    has_torch_function_unary,
    has_torch_function_variadic,
)
from torch.package import PackageExporter, PackageImporter
from torch.utils import set_module

from ._serialization import validate_map_location


type_trace_db = JitTypeTraceStore()  # DB to hold all call traces from MonkeyType

# Expose `graph_for` on compiled methods and functions.
torch._C.ScriptMethod.graph_for = _script_method_graph_for  # type: ignore[attr-defined]
torch._C.ScriptFunction.graph_for = _graph_for  # type: ignore[attr-defined]
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")


# Throws an error if a jit function is pickled.
# Helps to avoid Python crashes for Python versions 3.9.5 + when protocol 0 or 1
# is given as an argument.
def _reduce(self):
    """Installed as ``ScriptFunction.__reduce__``; always refuses to pickle.

    ``__reduce__`` is an instance method, so the argument is the ScriptFunction
    being pickled (the previous name ``cls`` was misleading).

    Raises:
        pickle.PickleError: always.
    """
    raise pickle.PickleError("ScriptFunction cannot be pickled")


ScriptFunction.__reduce__ = _reduce  # type: ignore[assignment]


if _enabled:
    # With the JIT enabled, Attribute is a (value, type) record that the
    # compiler inspects; ScriptModule.__setattr__ unwraps it.
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore[no-redef]
        return value


Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `jit.ScriptModule`
    subclasses.

    Though TorchScript can infer the correct type for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be a container of `Tensor`
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
      it is type `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Note: it's now preferred to use type annotations instead of `torch.jit.Attribute`:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.nn.Module):
            names: Dict[str, int]

            def __init__(self) -> None:
                super().__init__()
                self.names = {}

        m = AttributeModule()

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""


def _get_type_trace_db():
    # This is a private API. Use of this for external purposes is discouraged.
    return type_trace_db


# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
    """Return ``cls.<name>`` or ``None`` when the attribute does not exist."""
    return getattr(cls, name, None)


# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
    """Return ``True`` if ``cls`` looks like a new-style class.

    The check requires ``__class__`` plus either a ``__dict__`` entry in
    ``dir(cls)`` or a ``__slots__`` attribute. It always returns a ``bool``;
    previously it fell through to an implicit ``None`` when ``__class__`` was
    missing (still falsy, so callers see no difference).
    """
    return hasattr(cls, "__class__") and (
        "__dict__" in dir(cls) or hasattr(cls, "__slots__")
    )


# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
#  x not in view
#  x in view
#  view[name] = ...
#  view.values()
#  del view[name]
#  view.items()
#  view.keys()
#  len(view)


class OrderedDictWrapper:
    """Dict-like view over a C++ container of a ScriptModule (e.g. parameters, buffers).

    ``_c`` is the backing C++ object. It must provide ``items()``,
    ``contains(k)``, ``getattr(k)`` and ``setattr(k, v)``. Existing entries may
    be reassigned, but entries can never be added or deleted through this view.
    """

    def __init__(self, _c):
        self._c = _c

    def keys(self):
        return [k for k, _ in self.items()]

    def values(self):
        return [v for _, v in self.items()]

    def __len__(self):
        return len(self.values())

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        return self._c.items()

    def __setitem__(self, k, v):
        # Only existing names may be reassigned; the module's C++ type is fixed
        # after construction.
        if k not in self:
            raise RuntimeError(
                f"Can't add a new parameter after ScriptModule construction. Tried to add '{k}'"
            )
        self._c.setattr(k, v)

    def __contains__(self, k):
        return self._c.contains(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self._c.getattr(k)


class OrderedModuleDict(OrderedDictWrapper):
    """Submodule view whose reads go to a Python-side dict.

    Lookups, membership tests and iteration use ``python_dict``, so the same
    Python wrapper object is returned every time. Assignments update both the
    C++ ``ModuleDict`` and ``python_dict``, and only accept ScriptModule values.
    """

    def __init__(self, module, python_dict):
        super().__init__(torch._C.ModuleDict(module))
        # contains _both_ script modules and non-script python-only modules

        # because script modules are subclassed in python and the
        # C++ Module class will not hold references to them,
        # to ensure that you always get the same python value here
        # we store it in the python dict as well
        self._python_modules = python_dict

    def items(self):
        return self._python_modules.items()

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        # Cases where sub-module can be re-assigned after ScriptModule construction
        # 1. If the attr is a module interface type, it's guaranteed that the module is
        #    not inlined in the graph, so it's safe to swap a new ScriptModule in.
        # 2. If the new value is a ScriptModule with the same JIT type, IR won't change
        #    and it's legit to swap a new module in.
        # In these two cases we allow swapping a new scripted module and update the
        # corresponding python module dict to keep sync.
        # Note: the value to be swapped in has to be ScriptModule instead of nn.Module,
        # otherwise it's illegal and we throw error.
        if isinstance(v, ScriptModule):
            self._c.setattr(k, v)
            self._python_modules[k] = v
        else:
            raise RuntimeError(
                "Cannot re-assign modules in a ScriptModule with non-scripted "
                f"module, tried to replace existing module '{k}': {v}"
            )

    def __getitem__(self, k):
        return self._python_modules[k]


# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules and
#     parameters are initialized _before_ the script compiler resolves references to
#     `self.param` or `self.module`.
277class ScriptMeta(type): 278 def __init__(cls, name, bases, attrs): # noqa: B902 279 # Aggregate all the ScriptMethods and constants from superclasses 280 cls._methods: Dict[str, Any] = {} 281 cls._constants_set = set(getattr(cls, "__constants__", ())) 282 for base in reversed(bases): 283 for k, v in getattr(base, "_methods", {}).items(): 284 cls._methods[k] = v 285 base_constants: Set = getattr(base, "_constants_set", set()) 286 cls._constants_set = cls._constants_set.union(base_constants) 287 288 # find all the script methods of the current class 289 for k, v in sorted(attrs.items()): 290 if isinstance(v, ScriptMethodStub): 291 delattr(cls, k) 292 cls._methods[v.original_method.__name__] = v 293 294 if getattr(cls, "_disable_script_meta", False): 295 # We leave built-in ScriptModule types alone, since this metaclass 296 # is only for compiling user classes that inherit from 297 # ScriptModule. 298 super().__init__(name, bases, attrs) 299 return 300 301 original_init = getattr(cls, "__init__", lambda self: None) 302 303 @functools.wraps(original_init) 304 def init_then_script(self, *args, **kwargs): 305 num_methods = len(cls._methods) 306 original_init(self, *args, **kwargs) 307 added_methods_in_init = len(cls._methods) > num_methods 308 309 if type(self) == cls: 310 311 def make_stubs(module): 312 cls = type(module) 313 if hasattr(cls, "_methods"): 314 return [v for k, v in sorted(cls._methods.items())] 315 else: 316 return infer_methods_to_compile(module) 317 318 self.__dict__[ 319 "_actual_script_module" 320 ] = torch.jit._recursive.create_script_module( 321 self, make_stubs, share_types=not added_methods_in_init 322 ) 323 324 # Delete the Python attributes that now shadow the ScriptModule 325 # ones, so that __getattr__ and __setattr__ will properly find 326 # the scripted versions. 
327 concrete_type = self._actual_script_module._concrete_type 328 for name in concrete_type.get_attributes(): 329 delattr(self, name) 330 for name, _ in concrete_type.get_modules(): 331 delattr(self, name) 332 for name in ("_parameters", "_buffers", "_modules"): 333 delattr(self, name) 334 335 cls.__init__ = init_then_script # type: ignore[misc] 336 super().__init__(name, bases, attrs) 337 338 339class _CachedForward: 340 def __get__(self, obj, cls): 341 return self.__getattr__("forward") # type: ignore[attr-defined] 342 343 344class ScriptWarning(Warning): 345 pass 346 347 348def script_method(fn): 349 if not _enabled: 350 return fn 351 # NOTE: we need to traverse two frames here because the meta-class frame 352 # for ScriptModule will be present, as opposed to invoking @script on a 353 # a function or invoking define() on a CompilationUnit. 354 # The stack will look like: 355 # 356 # 0. createResolutionCallback() 357 # 1. script_method() 358 # 2. ScriptModule metaclass frame 359 # 3. Surrounding scope 360 # 361 # createResolutionCallback internally adds 1 to get us to the scope of this 362 # function (the calling function). Adding 2 gets us to the proper surrounding scope. 363 _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2) 364 ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule") 365 return ScriptMethodStub(_rcb, ast, fn) 366 367 368class ConstMap: 369 def __init__(self, const_mapping): 370 self.const_mapping = const_mapping 371 372 def __getattr__(self, attr): 373 return self.const_mapping[attr] 374 375 376def unpackage_script_module( 377 importer: PackageImporter, script_module_id: str 378) -> torch.nn.Module: 379 """ 380 Call by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function. 381 382 Performs work of loading and returning a ScriptModule from a ``torch.package`` archive. 
383 """ 384 if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader): 385 raise RuntimeError( 386 "Loading ScriptObjects from a PackageImporter created from a " 387 "directory is not supported. Use a package archive file instead." 388 ) 389 cu = torch._C.CompilationUnit() 390 cpp_module = torch._C._import_ir_module_from_package( 391 cu, 392 importer.zip_reader, 393 importer.storage_context, 394 validate_map_location(importer.last_map_location), 395 script_module_id, 396 ) 397 return wrap_cpp_module(cpp_module) 398 399 400if _enabled: 401 _magic_methods = [ 402 "__iter__", 403 "__len__", 404 "__neg__", 405 "__mul__", 406 "__contains__", 407 "__add__", 408 "__sub__", 409 "__pow__", 410 "__truediv__", 411 "__mod__", 412 "__ne__", 413 "__eq__", 414 "__lt__", 415 "__gt__", 416 "__le__", 417 "__ge__", 418 "__and__", 419 "__or__", 420 "__xor__", 421 "__getitem__", 422 "__setitem__", 423 "__call__", 424 "__int__", 425 "__float__", 426 "__bool__", 427 "__str__", 428 "__enter__", 429 "__exit__", 430 ] 431 432 class RecursiveScriptClass: 433 """Wrapper for a TorchScript class instance for use in Python. 434 435 An analogue of RecursiveScriptModule for regular objects that are not modules. 436 This class is a wrapper around a torch._C.ScriptObject that represents an instance 437 of a TorchScript class and allows it to be used in Python. 438 439 Attributes: 440 _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method 441 calls are forwarded. 442 _props [Dict[str, property]]: A dictionary of properties fetched from self._c and 443 exposed on this wrppaer. 444 """ 445 446 def __init__(self, cpp_class): 447 super().__init__() 448 self.__dict__["_initializing"] = True 449 self._c = cpp_class 450 451 # Add wrapped object's properties to this class instance. 
452 self._props = { 453 prop.name: property(prop.getter, prop.setter) 454 for prop in self._c._properties() 455 } 456 457 self.__dict__["_initializing"] = False 458 459 def __getattr__(self, attr): 460 if self.__dict__.get("_initializing"): 461 return super().__getattr__(attr) # type: ignore[misc] 462 463 if attr in self._props: 464 return self._props[attr].fget() # type: ignore[call-arg, misc] 465 466 return getattr(self._c, attr) 467 468 def __setattr__(self, attr, value): 469 if self.__dict__.get("_initializing"): 470 return super().__setattr__(attr, value) 471 472 if attr in self._props: 473 return self._props[attr].fset(value) # type: ignore[call-arg, misc] 474 475 setattr(self._c, attr, value) 476 477 # Delegate calls to magic methods like __len__ to the C++ module backing the 478 # RecursiveScriptClass. 479 def forward_magic_method(self, method_name, *args, **kwargs): 480 if not self._c._has_method(method_name): 481 raise TypeError 482 483 self_method = self.__getattr__(method_name) 484 return self_method(*args, **kwargs) 485 486 def __getstate__(self): 487 raise pickle.PickleError("ScriptClasses cannot be pickled") 488 489 def __iadd__(self, other): 490 if self._c._has_method("__iadd__"): 491 return self.forward_magic_method("__iadd__", other) 492 else: 493 return self.forward_magic_method("__add__", other) 494 495 for method_name in _magic_methods: 496 497 def method_template(self, *args, **kwargs): 498 return self.forward_magic_method(method_name, *args, **kwargs) 499 500 setattr(RecursiveScriptClass, method_name, method_template) 501 502 # this is a Python 'non-data descriptor' that causes the first access 503 # to ScriptModule's forward to look up the forward method and stash 504 # it in the objects dict. Due to the standard rules for attribute lookup, 505 # subsequent lookups will just directly return the previously looked up method. 506 # This is necessary because nn.Module defines forward as a method. 
If we 507 # did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward 508 # which always throws an exception. 509 510 class ScriptModule(Module, metaclass=ScriptMeta): 511 r"""Wrapper for C++ torch::jit::Module with methods, attributes, and parameters. 512 513 A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s 514 contain methods, attributes, parameters, and 515 constants. These can be accessed the same way as on a normal ``nn.Module``. 516 """ 517 518 __jit_unused_properties__ = [ 519 "code", 520 "code_with_constants", 521 "graph", 522 "inlined_graph", 523 "original_name", 524 ] 525 526 def __init__(self) -> None: 527 super().__init__() 528 529 forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment] 530 531 def __getattr__(self, attr): 532 if "_actual_script_module" not in self.__dict__: 533 return super().__getattr__(attr) 534 return getattr(self._actual_script_module, attr) 535 536 def __setattr__(self, attr, value): 537 if "_actual_script_module" not in self.__dict__: 538 # Unwrap torch.jit.Attribute into a regular setattr + record 539 # the provided type in __annotations__. 540 # 541 # This ensures that if we use the attr again in `__init__`, it 542 # will look like the actual value, not an instance of Attribute. 543 if isinstance(value, Attribute): 544 # NB: Ensure that we set __annotations__ on the specific 545 # class in question, and not on a superclass (which would 546 # be wrong wrong wrong!). 
547 # See also https://github.com/pytorch/pytorch/issues/39463 548 if "__annotations__" not in self.__class__.__dict__: 549 self.__class__.__annotations__ = {} 550 self.__annotations__[attr] = value.type 551 value = value.value 552 return super().__setattr__(attr, value) 553 554 setattr(self._actual_script_module, attr, value) 555 556 def define(self, src): 557 if "_actual_script_module" in self.__dict__: 558 # If we have completed initialization, just defer to the 559 # backing RecursiveScriptModule to eagerly compile the provided 560 # source. 561 return self._actual_script_module.define(src) 562 563 # Otherwise, we are still in the object's __init__. 564 # In that case, add `src` as a stub to be compiled. 565 # 566 # We use frames_up=1 to get to the proper surrounding scope. The stack 567 # will look like: 568 # 0. createResolutionCallback 569 # 1. define() 570 # 2. surrounding scope. 571 # 572 # createResolutionCallback internally adds 1 to get us to our frame, then 573 # we add 1 to get to the proper surrounding scope. 574 rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) 575 ast = torch._C._parse_source_def(src) 576 self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None) 577 578 def _replicate_for_data_parallel(self): 579 return self._actual_script_module._replicate_for_data_parallel() 580 581 def __reduce_package__(self, exporter: PackageExporter): 582 """Save a ScriptModule inside of a ``torch.package`` archive. 583 584 Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when 585 saving TorchScript objects. Performs act of saving a ScriptModule inside of 586 a ``torch.package`` archive. 587 588 Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s 589 Pickler's ``persistent_load`` function. 
590 """ 591 script_module_id = exporter.get_unique_id() 592 exporter.script_module_serializer.serialize(self._c, int(script_module_id)) 593 return (unpackage_script_module, (script_module_id,)) 594 595 class RecursiveScriptModule(ScriptModule): 596 # XXX: RecursiveScriptModule inherits from ScriptModule for the sole 597 # reason that it retains the existing isinstance(ScriptModule) 598 # behavior. 599 r"""Retain the existing isinstance(ScriptModule) behavior. 600 601 The core data structure in TorchScript is the ``ScriptModule``. It is an 602 analogue of torch's ``nn.Module`` and represents an entire model as a tree of 603 submodules. Like normal modules, each individual module in a ``ScriptModule`` can 604 have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented 605 as Python functions, but in ``ScriptModule``\s methods are implemented as 606 TorchScript functions, a statically-typed subset of Python that contains all 607 of PyTorch's built-in Tensor operations. This difference allows your 608 ``ScriptModule``\s code to run without the need for a Python interpreter. 609 610 ``ScriptModule``\s should not be created manually, instead use 611 either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`. 612 Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`. 613 614 * Tracing records the tensor operations as executed with a set of example inputs and uses these 615 operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing, 616 but values other than Tensors and control flow aren't captured in the graph. 617 618 * Scripting inspects the Python code of the model 619 and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow. 620 Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary. 
621 """ 622 623 _disable_script_meta = True 624 625 def __init__(self, cpp_module): 626 self.__dict__["_initializing"] = True 627 self._c = cpp_module 628 super().__init__() 629 # Delete the 'training' attribute set up by `Module.__init__`. It 630 # will get set on the underlying cpp module, so we delete it here 631 # to avoid this version shadowing the cpp module version. 632 delattr(self, "training") 633 634 @staticmethod 635 def _construct(cpp_module, init_fn): 636 """ 637 Construct a RecursiveScriptModule that's ready for use. 638 639 PyTorch code should use this to construct a RecursiveScriptModule instead 640 of instead of calling `__init__` directly, as it makes sure the 641 object is properly finalized (and in the future, we may take 642 control of how the RecursiveScriptModule instance is created). 643 644 Args: 645 cpp_module: The C++ Module that will hold the actual state of 646 this RecursiveScriptModule instance. 647 init_fn: Lambda that initializes the RecursiveScriptModule passed to it. 648 """ 649 script_module = RecursiveScriptModule(cpp_module) 650 init_fn(script_module) 651 652 # Finalize the ScriptModule: replace the nn.Module state with our 653 # custom implementations and flip the _initializing bit. 654 RecursiveScriptModule._finalize_scriptmodule(script_module) 655 return script_module 656 657 @staticmethod 658 def _finalize_scriptmodule(script_module): 659 script_module._parameters = OrderedDictWrapper( 660 torch._C.ParameterDict(script_module._c) 661 ) 662 script_module._buffers = OrderedDictWrapper( 663 torch._C.BufferDict(script_module._c) 664 ) 665 script_module._modules = OrderedModuleDict( 666 script_module._c, script_module._modules 667 ) 668 script_module._initializing = False 669 670 def _reconstruct(self, cpp_module): 671 """ 672 Re-construct an instance of RecursiveScriptModule using an instance of a C++ module. 673 674 Args: 675 cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around. 
676 """ 677 self.__init__(cpp_module) # type: ignore[misc] 678 679 # Copy the concrete type from the C++ module to this ScriptModule. 680 self._concrete_type = torch._C.ConcreteModuleType.from_jit_type( 681 self._c._type() 682 ) 683 684 # Copy submodules from the C++ module to this ScriptModule. 685 modules = {} 686 for name, cpp_module in torch._C.ModuleDict(self._c).items(): 687 modules[name] = wrap_cpp_module(cpp_module) 688 self._modules = OrderedModuleDict(self._c, modules) # type: ignore[assignment] 689 690 # Copy parameters and buffers. 691 self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) # type: ignore[assignment] 692 self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) # type: ignore[assignment] 693 694 # Get rid of the functions from the old C++ module. 695 self.__dict__ = { 696 k: v 697 for k, v in self.__dict__.items() 698 if not isinstance(v, torch._C.ScriptMethod) 699 } 700 self.__dict__["_initializing"] = False 701 702 @property 703 def graph(self): 704 r"""Return a string representation of the internal graph for the ``forward`` method. 705 706 See :ref:`interpreting-graphs` for details. 707 """ 708 return self._c._get_method("forward").graph 709 710 @property 711 def inlined_graph(self): 712 r""" 713 Return a string representation of the internal graph for the ``forward`` method. 714 715 This graph will be preprocessed to inline all function and method calls. 716 See :ref:`interpreting-graphs` for details. 717 """ 718 return self.forward.inlined_graph # type: ignore[attr-defined] 719 720 @property 721 def code(self): 722 r""" 723 Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method. 724 725 See :ref:`inspecting-code` for details. 726 """ 727 return self.forward.code # type: ignore[attr-defined] 728 729 @property 730 def code_with_constants(self): 731 r"""Return a tuple. 
732 733 Returns a tuple of: 734 735 [0] a pretty-printed representation (as valid Python syntax) of 736 the internal graph for the ``forward`` method. See `code`. 737 [1] a ConstMap following the CONSTANT.cN format of the output in [0]. 738 The indices in the [0] output are keys to the underlying constant's values. 739 740 See :ref:`inspecting-code` for details. 741 """ 742 r = self.forward.code_with_constants # type: ignore[attr-defined] 743 return (r[0], ConstMap(r[1])) 744 745 def save(self, f, **kwargs): 746 r"""Save with a file-like object. 747 748 save(f, _extra_files={}) 749 750 See :func:`torch.jit.save <torch.jit.save>` which accepts a file-like object. 751 This function, torch.save(), converts the object to a string, treating it as a path. 752 DO NOT confuse these two functions when it comes to the 'f' parameter functionality. 753 """ 754 return self._c.save(str(f), **kwargs) 755 756 def _save_for_lite_interpreter(self, *args, **kwargs): 757 r"""Add (or update) the bytecode session to the script model. 758 759 _save_for_lite_interpreter(f) 760 761 The updated model is used 762 in lite interpreter for mobile applications. 763 764 Args: 765 f: a string containing a file name. 766 _extra_files: Map from filename to contents which will be stored as part of 'f'. 
767 768 """ 769 return self._c._save_for_mobile(*args, **kwargs) 770 771 def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): 772 return self._c._save_to_buffer_for_mobile(*args, **kwargs) 773 774 def save_to_buffer(self, *args, **kwargs): 775 return self._c.save_to_buffer(*args, **kwargs) 776 777 def get_debug_state(self, *args, **kwargs): 778 return self._c.get_debug_state() 779 780 def extra_repr(self): 781 return f"original_name={self.original_name}" 782 783 def graph_for(self, *args, **kwargs): 784 return self.forward.graph_for(self, *args, **kwargs) # type: ignore[attr-defined] 785 786 @property 787 def original_name(self): 788 if type(self) == str(self._c._type().name()): 789 return "" 790 return str(self._c._type().name()) 791 792 def define(self, src): 793 # We use frames_up=1 to get to the proper surrounding scope. The stack 794 # will look like: 795 # 0. createResolutionCallback 796 # 1. define() 797 # 2. surrounding scope. 798 # 799 # createResolutionCallback internally adds 1 to get us to our frame, then 800 # we add 1 to get to the proper surrounding scope. 801 rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) 802 self._c._define(self._concrete_type, src, rcb) 803 804 def __getattr__(self, attr): 805 if "_initializing" not in self.__dict__: 806 raise RuntimeError( 807 "ScriptModule has not been initialized, did you forget to call super's init?" 808 ) 809 810 if self._initializing: 811 return super().__getattr__(attr) 812 813 # _modules check is before hasattr since modules are included as attributes in _c, 814 # but we want to get the python wrapper from _modules instead of the raw _c object. 
815 if attr in self._modules: 816 return self._modules[attr] 817 elif self._c.hasattr(attr): 818 return self._c.getattr(attr) 819 elif self._c._has_method(attr): 820 script_method = self._c._get_method(attr) 821 # cache method so future calls do not go through __getattr__ 822 # to improve invocation performance 823 self.__dict__[attr] = script_method 824 return script_method 825 826 return super().__getattr__(attr) 827 828 def __setattr__(self, attr, value): 829 if self._initializing: 830 return super().__setattr__(attr, value) 831 832 if attr in self._modules: 833 self._modules[attr] = value 834 elif self._c.hasattr(attr): 835 self._c.setattr(attr, value) 836 elif ( 837 hasattr(self, "_concrete_type") 838 and attr in self._concrete_type.get_constants().keys() 839 ): 840 # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. 841 # We should encode constants as class type attributes (or something) so it persists across save/load. 842 raise AttributeError( 843 f"Cannot mutate TorchScript constant value: '{attr}'. Value: '{value}'" 844 ) 845 else: 846 # We allow setting Python attributes on the ScriptModule, for 847 # when people want to stash some convenience info on it. 848 # TODO: it's possible that the following is confusing: 849 # s = torch.jit.script(...) 850 # s.python_attr = ... 851 # s.save() <--- this doesn't have `python_attr` 852 # It's fairly trivial to save enough info to warn in this case. 853 return super().__setattr__(attr, value) 854 855 def __copy__(self): 856 return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c)) 857 858 def __deepcopy__(self, memo): 859 return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo)) 860 861 # Python magic methods do method lookups on an object's class type, instead of looking up 862 # the method defines on the class instance. 
In order to continue to expose the magic methods 863 # of builtin-containers (ModuleList, Sequential, ModuleDict) to Python, we 864 # define magic methods here as a shim to the correct attribute. 865 def forward_magic_method(self, method_name, *args, **kwargs): 866 self_method = getattr(self, method_name) 867 if getattr(self_method, "__func__", None) == getattr( 868 RecursiveScriptModule, method_name 869 ): 870 raise NotImplementedError 871 return self_method(*args, **kwargs) 872 873 def __iter__(self): 874 return self.forward_magic_method("__iter__") 875 876 def __getitem__(self, idx): 877 return self.forward_magic_method("__getitem__", idx) 878 879 def __len__(self): 880 return self.forward_magic_method("__len__") 881 882 def __contains__(self, key): 883 return self.forward_magic_method("__contains__", key) 884 885 # dir is defined by the base nn.Module, so instead of throwing if 886 # it is not overridden, we call into the nn.Module __dir__ method 887 def __dir__(self): 888 self_method = self.__dir__ 889 if ( 890 self_method.__func__ # type: ignore[attr-defined] 891 == _get_function_from_type(RecursiveScriptModule, "__dir__") 892 ): 893 return super().__dir__() 894 return self_method() 895 896 # to resolve bool(value), Python looks if __bool__ is defined then __iter__ 897 # is defined then returns true for classes. 
Since __iter__() on this 898 # class throws if it isn't overridden, we define __bool__ to preserve default behavior 899 def __bool__(self): 900 self_method = self.__bool__ 901 if ( 902 self_method.__func__ # type: ignore[attr-defined] 903 == _get_function_from_type(RecursiveScriptModule, "__bool__") 904 ): 905 return True 906 return self_method() 907 908 def _replicate_for_data_parallel(self): 909 # we have to initialize ScriptModule properly so that 910 # it works with pybind11 911 def init_fn(script_module): 912 # Don't do anything here, we'll initialize the ScriptModule below 913 return 914 915 return RecursiveScriptModule._construct( 916 self._c._replicate_for_data_parallel(), init_fn 917 ) 918 919 # Need to copy all RecursiveScriptModule methods to ScriptModule. 920 # 921 # This is because `super().foo()` does not use 922 # `__getattr__` to look up `foo`. So we need to make each method available on 923 # the ScriptModule manually. 924 for name, item in RecursiveScriptModule.__dict__.items(): 925 if not callable(item) and not isinstance(item, property): 926 continue 927 if name.startswith("__") or hasattr(ScriptModule, name): 928 continue 929 # We can copy over the implementation wholesale because besides the 930 # `super()` thing above, ScriptModule behaves exactly like 931 # RecursiveScriptModule 932 setattr(ScriptModule, name, item) 933 934 def _get_methods(cls): 935 import inspect 936 937 # In Python 3 unbound methods are functions, but in Python 2 they are methods 938 return inspect.getmembers( 939 cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) 940 ) 941 942 _compiled_methods_allowlist = { 943 "forward", 944 "register_buffer", 945 "register_parameter", 946 "register_module", 947 "add_module", 948 "_apply", 949 "apply", 950 "cuda", 951 "cpu", 952 "to", 953 "type", 954 "float", 955 "double", 956 "half", 957 "state_dict", 958 "_save_to_state_dict", 959 "load_state_dict", 960 "_load_from_state_dict", 961 "_named_members", 962 
"parameters", 963 "named_parameters", 964 "buffers", 965 "named_buffers", 966 "children", 967 "named_children", 968 "modules", 969 "named_modules", 970 "zero_grad", 971 "share_memory", 972 "_get_name", 973 "extra_repr", 974 "_slow_forward", 975 "_tracing_name", 976 "eval", 977 "train", 978 "get_extra_state", 979 "set_extra_state", 980 } 981 982 def _make_fail(name): 983 def fail(self, *args, **kwargs): 984 raise RuntimeError(name + " is not supported on ScriptModules") 985 986 return fail 987 988 for name, method in _get_methods(torch.nn.Module): 989 if name.startswith("__") or name.endswith("_call_impl"): 990 continue 991 if ( 992 name not in RecursiveScriptModule.__dict__ 993 and name not in _compiled_methods_allowlist 994 ): 995 setattr(RecursiveScriptModule, method.__name__, _make_fail(name)) 996 997 998else: 999 # TODO MAKE SURE THAT DISABLING WORKS 1000 class RecursiveScriptClass: # type: ignore[no-redef] 1001 pass 1002 1003 class ScriptModule(torch.nn.Module): # type: ignore[no-redef] 1004 def __init__(self, arg=None): 1005 super().__init__() 1006 1007 class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef] 1008 def __init__(self, arg=None): 1009 super().__init__() 1010 1011 1012def call_prepare_scriptable_func_impl(obj, memo): 1013 if not isinstance(obj, torch.nn.Module): 1014 return obj 1015 1016 obj_id = id(obj) 1017 1018 # If obj_id is in memo, obj has already been prepared or is being 1019 # prepared in another call up the stack. 1020 if obj_id in memo: 1021 return memo[id(obj)] 1022 1023 obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator] 1024 # Record obj in memo to avoid infinite recursion in the case of cycles in the module 1025 # hierarchy when recursing below. 
1026 memo[obj_id] = obj 1027 1028 new_obj_dict = {} 1029 1030 for name, sub_module in obj.__dict__.items(): 1031 if name == "_modules": 1032 for k, v in sub_module.items(): 1033 sub_module[k] = call_prepare_scriptable_func_impl(v, memo) 1034 new_obj_dict[name] = sub_module 1035 elif isinstance(sub_module, torch.nn.Module) and not isinstance( 1036 sub_module, ScriptModule 1037 ): 1038 new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo) 1039 else: 1040 new_obj_dict[name] = sub_module 1041 1042 for k, v in new_obj_dict.items(): 1043 obj.__dict__[name] = v 1044 1045 return obj 1046 1047 1048def call_prepare_scriptable_func(obj): 1049 memo: Dict[int, torch.nn.Module] = {} 1050 return call_prepare_scriptable_func_impl(obj, memo) 1051 1052 1053def create_script_dict(obj): 1054 """ 1055 Create a ``torch._C.ScriptDict`` instance with the data from ``obj``. 1056 1057 Args: 1058 obj (dict): The Python dictionary that is used to initialize the ``ScriptDict`` 1059 returned by this function. 1060 1061 Returns: 1062 An instance of ``torch._C.ScriptDict`` that has the same data as ``obj`` 1063 and can be passed between Python and TorchScript with reference semantics and 1064 zero copy overhead. 1065 """ 1066 return torch._C.ScriptDict(obj) # type: ignore[attr-defined] 1067 1068 1069def create_script_list(obj, type_hint=None): 1070 """ 1071 Create a ``torch._C.ScriptList`` instance with the data from ``obj``. 1072 1073 Args: 1074 obj (dict): The Python list that is used to initialize the ``ScriptList`` 1075 returned by this function. 1076 Returns: 1077 An instance of ``torch._C.ScriptList`` that has the same data as ``obj`` 1078 and can be passed between Python and TorchScript with reference semantics and 1079 zero copy overhead. 
1080 """ 1081 return torch._C.ScriptList(obj) # type: ignore[attr-defined] 1082 1083 1084_TOPLEVEL: bool = True 1085 1086 1087def _script_impl( 1088 obj, 1089 optimize=None, 1090 _frames_up=0, 1091 _rcb=None, 1092 example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None, 1093): 1094 global type_trace_db 1095 1096 if optimize is not None: 1097 warnings.warn( 1098 "`optimize` is deprecated and has no effect. " 1099 "Use `with torch.jit.optimized_execution()` instead", 1100 FutureWarning, 1101 stacklevel=3, 1102 ) 1103 1104 # No-op for modules, functions, class instances that are already scripted 1105 if isinstance(obj, RecursiveScriptClass): 1106 return obj 1107 if isinstance(obj, ScriptModule): 1108 return obj 1109 if isinstance(obj, ScriptFunction): 1110 return obj 1111 1112 if example_inputs: 1113 # If MonkeyType is installed, enable profile directed type annotation 1114 # Check if example_inputs are defined and generate call traces 1115 # for the method by running eager mode version of the method with 1116 # the provide example inputs. This logs all the traces in type_trace_db 1117 type_trace_db = JitTypeTraceStore() 1118 if monkeytype_trace: 1119 monkeytype_config = JitTypeTraceConfig(type_trace_db) 1120 with monkeytype_trace(monkeytype_config): 1121 if isinstance(example_inputs, Dict): 1122 # If the obj is an nn.Module or a class, then each method is 1123 # executed with the arguments provided in the example inputs. 1124 # example inputs here will be of type Dict(class.method, (arguments)) 1125 # This is used to infer type annotations for those methods 1126 # which are not called directly under the hood of monkeytype. 1127 for module, example_input in example_inputs.items(): 1128 for example in example_input: 1129 module(*example) 1130 elif isinstance(example_inputs, List): 1131 for examples in example_inputs: 1132 obj(*examples) 1133 else: 1134 raise ValueError( 1135 "Error: Unable to infer types. 
Please format the inputs to type `List[Tuple]`" 1136 " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType." 1137 ) 1138 else: 1139 warnings.warn( 1140 "Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " 1141 "to enable Profile-Directed Typing in TorchScript. Refer to " 1142 "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. " 1143 ) 1144 1145 if isinstance(obj, torch.nn.Module): 1146 obj = call_prepare_scriptable_func(obj) 1147 return torch.jit._recursive.create_script_module( 1148 obj, torch.jit._recursive.infer_methods_to_compile 1149 ) 1150 else: 1151 obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator] 1152 1153 if isinstance(obj, dict): 1154 return create_script_dict(obj) 1155 if isinstance(obj, list): 1156 return create_script_list(obj) 1157 1158 if inspect.isclass(obj): 1159 qualified_name = _qualified_name(obj) 1160 # If this type is a `nn.Module` subclass, they probably meant to pass 1161 # an instance instead of a Module 1162 if issubclass(obj, torch.nn.Module): 1163 raise RuntimeError( 1164 f"Type '{obj}' cannot be compiled since it inherits from nn.Module, pass an instance instead" 1165 ) 1166 1167 # Enums are automatically usable in TorchScript, explicitly scripting 1168 # is not necessary, but not harmful either. 1169 if issubclass(obj, enum.Enum): 1170 return obj 1171 1172 if not _is_new_style_class(obj): 1173 raise RuntimeError( 1174 "TorchScript classes must be new-style classes. " 1175 "Please inherit from 'object'." 1176 ) 1177 if len(obj.mro()) > 2: 1178 raise RuntimeError( 1179 "TorchScript classes does not support inheritance yet. " 1180 "Please directly inherit from 'object'." 
1181 ) 1182 if _rcb is None: 1183 _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) 1184 _compile_and_register_class(obj, _rcb, qualified_name) 1185 return obj 1186 elif inspect.isfunction(obj) or inspect.ismethod(obj): 1187 qualified_name = _qualified_name(obj) 1188 # this is a decorated fn, and we need to the underlying fn and its rcb 1189 if hasattr(obj, "__script_if_tracing_wrapper"): 1190 obj = obj.__original_fn # type: ignore[union-attr] 1191 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) 1192 1193 # some functions are explicitly marked as not supported in script mode 1194 if hasattr(obj, "__script_unsupported"): 1195 raise RuntimeError("TorchScript error: " + obj.__script_unsupported) 1196 1197 _check_directly_compile_overloaded(obj) 1198 maybe_already_compiled_fn = _try_get_jit_cached_function(obj) 1199 if maybe_already_compiled_fn: 1200 maybe_already_compiled_fn._torchdynamo_inline = obj # type: ignore[attr-defined] 1201 return maybe_already_compiled_fn 1202 ast = get_jit_def(obj, obj.__name__) 1203 if _rcb is None: 1204 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) 1205 fn = torch._C._jit_script_compile( 1206 qualified_name, ast, _rcb, get_default_args(obj) 1207 ) 1208 # Forward docstrings 1209 fn.__doc__ = obj.__doc__ 1210 # Allow torch.compile() to inline 1211 fn._torchdynamo_inline = obj # type: ignore[attr-defined] 1212 _set_jit_function_cache(obj, fn) 1213 return fn 1214 else: 1215 return torch.jit._recursive.create_script_class(obj) 1216 1217 1218def script( 1219 obj, 1220 optimize=None, 1221 _frames_up=0, 1222 _rcb=None, 1223 example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None, 1224): 1225 r"""Script the function. 1226 1227 Scripting a function or ``nn.Module`` will inspect the source code, compile 1228 it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or 1229 :class:`ScriptFunction`. 
TorchScript itself is a subset of the Python language, so not all 1230 features in Python work, but we provide enough functionality to compute on 1231 tensors and do control-dependent operations. For a complete guide, see the 1232 :ref:`language-reference`. 1233 1234 Scripting a dictionary or list copies the data inside it into a TorchScript instance than can be 1235 subsequently passed by reference between Python and TorchScript with zero copy overhead. 1236 1237 ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists 1238 and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. 1239 1240 Args: 1241 obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, 1242 dictionary, or list to compile. 1243 example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs 1244 to annotate the arguments for a function or ``nn.Module``. 1245 1246 Returns: 1247 If ``obj`` is ``nn.Module``, ``script`` returns 1248 a :class:`ScriptModule` object. The returned :class:`ScriptModule` will 1249 have the same set of sub-modules and parameters as the 1250 original ``nn.Module``. If ``obj`` is a standalone function, 1251 a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then 1252 ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``, 1253 then ``script`` returns an instance of `torch._C.ScriptList`. 1254 1255 **Scripting a function** 1256 The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction` 1257 by compiling the body of the function. 1258 1259 Example (scripting a function): 1260 1261 .. 
testcode:: 1262 1263 import torch 1264 1265 @torch.jit.script 1266 def foo(x, y): 1267 if x.max() > y.max(): 1268 r = x 1269 else: 1270 r = y 1271 return r 1272 1273 print(type(foo)) # torch.jit.ScriptFunction 1274 1275 # See the compiled graph as Python code 1276 print(foo.code) 1277 1278 # Call the function using the TorchScript interpreter 1279 foo(torch.ones(2, 2), torch.ones(2, 2)) 1280 1281 .. testoutput:: 1282 :hide: 1283 1284 ... 1285 1286 ****Scripting a function using example_inputs** 1287 Example inputs can be used to annotate a function arguments. 1288 1289 Example (annotating a function before scripting): 1290 1291 .. testcode:: 1292 1293 import torch 1294 1295 def test_sum(a, b): 1296 return a + b 1297 1298 # Annotate the arguments to be int 1299 scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) 1300 1301 print(type(scripted_fn)) # torch.jit.ScriptFunction 1302 1303 # See the compiled graph as Python code 1304 print(scripted_fn.code) 1305 1306 # Call the function using the TorchScript interpreter 1307 scripted_fn(20, 100) 1308 1309 .. testoutput:: 1310 :hide: 1311 1312 ... 1313 1314 **Scripting an nn.Module** 1315 Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively 1316 compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses 1317 features supported in TorchScript, no changes to the original module code should be necessary. ``script`` 1318 will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of 1319 the original module. 1320 1321 Example (scripting a simple module with a Parameter): 1322 1323 .. 
testcode:: 1324 1325 import torch 1326 1327 class MyModule(torch.nn.Module): 1328 def __init__(self, N, M): 1329 super().__init__() 1330 # This parameter will be copied to the new ScriptModule 1331 self.weight = torch.nn.Parameter(torch.rand(N, M)) 1332 1333 # When this submodule is used, it will be compiled 1334 self.linear = torch.nn.Linear(N, M) 1335 1336 def forward(self, input): 1337 output = self.weight.mv(input) 1338 1339 # This calls the `forward` method of the `nn.Linear` module, which will 1340 # cause the `self.linear` submodule to be compiled to a `ScriptModule` here 1341 output = self.linear(output) 1342 return output 1343 1344 scripted_module = torch.jit.script(MyModule(2, 3)) 1345 1346 Example (scripting a module with traced submodules): 1347 1348 .. testcode:: 1349 1350 import torch 1351 import torch.nn as nn 1352 import torch.nn.functional as F 1353 1354 class MyModule(nn.Module): 1355 def __init__(self) -> None: 1356 super().__init__() 1357 # torch.jit.trace produces a ScriptModule's conv1 and conv2 1358 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) 1359 self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) 1360 1361 def forward(self, input): 1362 input = F.relu(self.conv1(input)) 1363 input = F.relu(self.conv2(input)) 1364 return input 1365 1366 scripted_module = torch.jit.script(MyModule()) 1367 1368 To compile a method other than ``forward`` (and recursively compile anything it calls), add 1369 the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation 1370 use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`. 
1371 1372 Example (an exported and ignored method in a module):: 1373 1374 import torch 1375 import torch.nn as nn 1376 1377 class MyModule(nn.Module): 1378 def __init__(self) -> None: 1379 super().__init__() 1380 1381 @torch.jit.export 1382 def some_entry_point(self, input): 1383 return input + 10 1384 1385 @torch.jit.ignore 1386 def python_only_fn(self, input): 1387 # This function won't be compiled, so any 1388 # Python APIs can be used 1389 import pdb 1390 pdb.set_trace() 1391 1392 def forward(self, input): 1393 if self.training: 1394 self.python_only_fn(input) 1395 return input * 99 1396 1397 scripted_module = torch.jit.script(MyModule()) 1398 print(scripted_module.some_entry_point(torch.randn(2, 2))) 1399 print(scripted_module(torch.randn(2, 2))) 1400 1401 Example ( Annotating forward of nn.Module using example_inputs):: 1402 1403 import torch 1404 import torch.nn as nn 1405 from typing import NamedTuple 1406 1407 class MyModule(NamedTuple): 1408 result: List[int] 1409 1410 class TestNNModule(torch.nn.Module): 1411 def forward(self, a) -> MyModule: 1412 result = MyModule(result=a) 1413 return result 1414 1415 pdt_model = TestNNModule() 1416 1417 # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward 1418 scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) 1419 1420 # Run the scripted_model with actual inputs 1421 print(scripted_model([20])) 1422 """ 1423 if not _enabled: 1424 return obj 1425 try: 1426 global _TOPLEVEL 1427 prev = _TOPLEVEL 1428 _TOPLEVEL = False 1429 ret = _script_impl( 1430 obj=obj, 1431 optimize=optimize, 1432 _frames_up=_frames_up + 1, 1433 _rcb=_rcb, 1434 example_inputs=example_inputs, 1435 ) 1436 1437 if prev: 1438 log_torchscript_usage("script", model_id=_get_model_id(ret)) 1439 1440 return ret 1441 finally: 1442 _TOPLEVEL = prev 1443 1444 1445# overloads are registered in _jit_internal and compiled here so that _overload 1446# can be used in 
nn/functional.py without an import cycle 1447 1448 1449def _check_overload_defaults(impl_defaults, overload_defaults, loc): 1450 for name, overload_value in overload_defaults.items(): 1451 if name not in impl_defaults or impl_defaults[name] != overload_value: 1452 raise torch.jit.frontend.FrontendError( 1453 loc, 1454 "Default parameters on overloads do not affect the runtime so they " 1455 "must equal to the default parameter on the implementation function. Found on " 1456 f"parameter {name}", 1457 ) 1458 1459 1460def _compile_function_with_overload(overload_fn, qual_name, impl_fn): 1461 overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl() 1462 overload_signature = torch.jit.annotations.get_signature( 1463 overload_fn, None, None, inspect.ismethod(overload_fn) 1464 ) 1465 impl_ast = get_jit_def(impl_fn, impl_fn.__name__) 1466 overload_defaults = get_default_args(overload_fn) 1467 implementation_defaults = get_default_args(impl_fn) 1468 _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn) 1469 _check_overload_defaults( 1470 implementation_defaults, overload_defaults, overload_decl.range() 1471 ) 1472 fn = torch._C._jit_script_compile_overload( 1473 qual_name, 1474 overload_decl, 1475 impl_ast, 1476 _rcb, 1477 implementation_defaults, 1478 overload_signature, 1479 ) 1480 return fn 1481 1482 1483def _get_overloads(obj): 1484 # check for cached compiled fns 1485 existing_compiled_fns = _try_get_jit_cached_overloads(obj) 1486 qual_name = _qualified_name(obj) 1487 uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name) 1488 if uncompiled_overloads is None: 1489 return existing_compiled_fns 1490 1491 if obj in uncompiled_overloads: 1492 raise RuntimeError( 1493 _jit_internal.get_overload_no_implementation_error_message("function", obj) 1494 ) 1495 1496 compiled_fns = [] 1497 for overload_fn in uncompiled_overloads: 1498 compiled_fns.append( 1499 _compile_function_with_overload(overload_fn, qual_name, obj) 1500 ) 1501 1502 if 
existing_compiled_fns: 1503 compiled_fns = existing_compiled_fns + compiled_fns 1504 1505 # cache compilation, remove information stored to do compilation 1506 _set_jit_overload_cache(obj, compiled_fns) 1507 _jit_internal._clear_fn_overloads(qual_name) 1508 return compiled_fns 1509 1510 1511def _check_directly_compile_overloaded(obj): 1512 qual_name = _qualified_name(obj) 1513 if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj): 1514 raise RuntimeError( 1515 f"Function {qual_name} cannot be directly compiled because it" 1516 " is overloaded. It must be used in a context of a function" 1517 " where its inputs can determine which overload to call." 1518 ) 1519 1520 1521def interface(obj): 1522 r"""Decorate to annotate classes or modules of different types. 1523 1524 This decorator can be used to define an interface that can be used to annotate 1525 classes or modules of different types. This can be used for to annotate a submodule 1526 or attribute class that could have different types that implement the same 1527 interface, or which could be swapped at runtime; or to store a list of modules or 1528 classes of varying types. 1529 1530 It is sometimes used to implement "Callables" - functions or modules that implement 1531 an interface but whose implementations differ and which can be swapped out. 1532 1533 Example: 1534 .. 
testcode:: 1535 1536 import torch 1537 from typing import List 1538 1539 @torch.jit.interface 1540 class InterfaceType: 1541 def run(self, x: torch.Tensor) -> torch.Tensor: 1542 pass 1543 1544 # implements InterfaceType 1545 @torch.jit.script 1546 class Impl1: 1547 def run(self, x: torch.Tensor) -> torch.Tensor: 1548 return x.relu() 1549 1550 class Impl2(torch.nn.Module): 1551 def __init__(self) -> None: 1552 super().__init__() 1553 self.val = torch.rand(()) 1554 1555 @torch.jit.export 1556 def run(self, x: torch.Tensor) -> torch.Tensor: 1557 return x + self.val 1558 1559 def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor: 1560 return impls[idx].run(val) 1561 1562 user_fn_jit = torch.jit.script(user_fn) 1563 1564 impls = [Impl1(), torch.jit.script(Impl2())] 1565 val = torch.rand(4, 4) 1566 user_fn_jit(impls, 0, val) 1567 user_fn_jit(impls, 1, val) 1568 """ 1569 if not inspect.isclass(obj): 1570 raise RuntimeError("interface must be applied to a class") 1571 if not _is_new_style_class(obj): 1572 raise RuntimeError("TorchScript interfaces must inherit from 'object'") 1573 1574 # Expected MRO is: 1575 # User module 1576 # torch.nn.modules.module.Module 1577 # object 1578 is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3 1579 1580 if not is_module_interface and len(obj.mro()) > 2: 1581 raise RuntimeError( 1582 "TorchScript interface does not support inheritance yet. " 1583 "Please directly inherit from 'object' or 'nn.Module'." 
1584 ) 1585 1586 qualified_name = _qualified_name(obj) 1587 rcb = _jit_internal.createResolutionCallbackFromFrame(1) 1588 # if this type is a `nn.Module` subclass, generate a module interface type 1589 # instead of a class interface type; a module interface type only compiles 1590 # the user provided methods as part of the interface 1591 ast = get_jit_class_def(obj, obj.__name__) 1592 mangled_classname = torch._C._jit_script_interface_compile( 1593 qualified_name, ast, rcb, is_module_interface 1594 ) 1595 obj.__torch_script_interface__ = mangled_classname 1596 return obj 1597 1598 1599def _recursive_compile_class(obj, loc): 1600 _qual_name = _qualified_name(obj) 1601 # We're starting a new compilation, so update the error call stack in 1602 # case it fails 1603 error_stack = torch._C.CallStack(_qual_name, loc) 1604 rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) 1605 return _compile_and_register_class(obj, rcb, _qual_name) 1606 1607 1608CompilationUnit = torch._C.CompilationUnit 1609set_module(CompilationUnit, "torch.jit") 1610 1611 1612def pad(s: str, padding: int, offset: int = 0, char: str = " "): 1613 if padding >= len(s): 1614 padding -= len(s) 1615 return "".join([char for _ in range(padding + offset)]) + s 1616 1617 1618class _ScriptProfileColumn: 1619 def __init__(self, header: str, alignment: int = 4, offset: int = 0): 1620 self.header = header 1621 self.alignment = alignment 1622 self.offset = offset 1623 self.rows: Dict[int, Any] = {} 1624 1625 def add_row(self, lineno: int, value: Any): 1626 self.rows[lineno] = value 1627 1628 def materialize(self): 1629 max_length = len(self.header) 1630 rows: List[Tuple[int, str]] = [] 1631 for key, value in self.rows.items(): 1632 cell = str(value) 1633 rows.append((key, cell)) 1634 max_length = max(len(cell), max_length) 1635 1636 if self.alignment > 0: 1637 padding = max_length + self.alignment 1638 padding -= padding % self.alignment 1639 else: 1640 padding = 0 1641 1642 rows = [(key, pad(cell, 
padding, self.offset)) for key, cell in rows] 1643 return pad(self.header, padding, self.offset), rows 1644 1645 1646class _ScriptProfileTable: 1647 def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]): 1648 self.cols = cols 1649 self.source_range = source_range 1650 1651 def dump_string(self): 1652 outputs: List[str] = [] 1653 cells: List[Tuple[str, Dict[int, str]]] = [] 1654 header_buffer = "" 1655 for col in self.cols: 1656 header, rows = col.materialize() 1657 header_buffer += header 1658 cells.append((header, dict(rows))) 1659 1660 outputs.append(header_buffer) 1661 outputs.append(pad("", len(header_buffer), 0, "=")) 1662 for line in self.source_range: 1663 row_buffer = "" 1664 for header, rows in cells: 1665 cell = rows.get(line) 1666 if cell is None: 1667 row_buffer += pad("", len(header)) 1668 else: 1669 row_buffer += cell 1670 outputs.append(row_buffer) 1671 return "\n".join(outputs) 1672 1673 1674class _ScriptProfile: 1675 def __init__(self) -> None: 1676 self.profile = classes.profiling._ScriptProfile() 1677 1678 def enable(self): 1679 self.profile.enable() 1680 1681 def disable(self): 1682 self.profile.disable() 1683 1684 def dump_string(self) -> str: 1685 outputs: List[str] = [] 1686 for source_stats in self.profile._dump_stats(): 1687 source_ref = source_stats.source() 1688 source_lines = source_ref.text().splitlines() 1689 dedent = min(len(line) - len(line.lstrip(" ")) for line in source_lines) 1690 source_lines = [line[dedent:] for line in source_lines] 1691 1692 start_line = source_ref.starting_lineno() 1693 end_line = start_line + len(source_lines) 1694 source_range = range(start_line, end_line) 1695 lineno = _ScriptProfileColumn("Line #") 1696 hits = _ScriptProfileColumn("Hits") 1697 time_ns = _ScriptProfileColumn("Time (ns)") 1698 line_contents = _ScriptProfileColumn("Line Contents", 0, 1) 1699 stats = source_stats.line_map() 1700 for line in source_range: 1701 lineno.add_row(line, line) 1702 line_contents.add_row(line, 
source_lines[line - start_line]) 1703 stat = stats.get(line) 1704 if stat is not None: 1705 hits.add_row(line, stat.count()) 1706 time_ns.add_row(line, stat.duration_ns()) 1707 1708 table = _ScriptProfileTable( 1709 [lineno, hits, time_ns, line_contents], list(source_range) 1710 ) 1711 outputs.append(table.dump_string()) 1712 return "\n\n".join(outputs) 1713 1714 def dump(self): 1715 print(self.dump_string()) 1716 1717 1718def _unwrap_optional(x): 1719 assert x is not None, "Unwrapping null optional" 1720 return x 1721 1722 1723_register_builtin(_unwrap_optional, "aten::_unwrap_optional") 1724_register_builtin(_jit_internal.is_scripting, "aten::is_scripting") 1725_register_builtin(has_torch_function, "aten::has_torch_function") 1726_register_builtin(has_torch_function_unary, "aten::has_torch_function") 1727_register_builtin(has_torch_function_variadic, "aten::has_torch_function") 1728