# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import collections
import copyreg
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn.modules.container import Module, ModuleDict, ModuleList
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


__all__ = [
    "cached",
    "ParametrizationList",
    "register_parametrization",
    "is_parametrized",
    "remove_parametrizations",
    "type_before_parametrizations",
    "transfer_parametrizations_and_params",
]

_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}


@contextmanager
def cached():
    r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`.

    The value of the parametrized objects is computed and cached the first time
    they are required when this context manager is active. The cached values are
    discarded when leaving the context manager.

    This is useful when using a parametrized parameter more than once in the forward pass.
    An example of this is when parametrizing the recurrent kernel of an RNN or when
    sharing weights.

    The simplest way to activate the cache is by wrapping the forward pass of the neural network

    .. code-block:: python

        import torch.nn.utils.parametrize as P
        ...
        with P.cached():
            output = model(inputs)

    in training and evaluation. One may also wrap the parts of the module that use
    the parametrized tensors several times. For example, the loop of an RNN with a
    parametrized recurrent kernel:

    .. code-block:: python

        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
    """
    global _cache
    global _cache_enabled
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        if not _cache_enabled:
            _cache = {}


def _register_parameter_or_buffer(module, name, X):
    if isinstance(X, Parameter):
        module.register_parameter(name, X)
    else:
        module.register_buffer(name, X)


def _maybe_set(dest: Tensor, src: Tensor) -> None:
    should_swap = (
        get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
    )
    if should_swap:
        if isinstance(dest, Parameter) and not isinstance(src, Parameter):
            src = Parameter(src, requires_grad=dest.requires_grad)
        torch.utils.swap_tensors(dest, src)
    else:
        dest.set_(src)  # type: ignore[call-overload]


class ParametrizationList(ModuleList):
    r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.

    It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
    has been parametrized with :func:`register_parametrization`.

    If the first registered parametrization has a ``right_inverse`` that returns one tensor or
    does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
    it will hold the tensor under the name ``original``.
    If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
    ``original0``, ``original1``, ...

    .. warning::
        This class is used internally by :func:`register_parametrization`. It is documented
        here for completeness. It shall not be instantiated by the user.

    Args:
        modules (sequence): sequence of modules representing the parametrizations
        original (Parameter or Tensor): parameter or buffer that is parametrized
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.
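
    Example:
        An illustrative sketch of how this container appears after registration;
        ``nn.Identity`` stands in for a real parametrization:

        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = P.register_parametrization(nn.Linear(3, 3), "weight", nn.Identity())
        >>> isinstance(m.parametrizations.weight, P.ParametrizationList)
        True
        >>> m.parametrizations.weight.original.shape
        torch.Size([3, 3])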
    """

    original: Tensor
    unsafe: bool

    def __init__(
        self,
        modules: Sequence[Module],
        original: Union[Tensor, Parameter],
        unsafe: bool = False,
    ) -> None:
        # We require this because we need to treat the first parametrization differently
        # This should never throw, unless this class is used from the outside
        if len(modules) == 0:
            raise ValueError("ParametrizationList requires one or more modules.")

        super().__init__(modules)
        self.unsafe = unsafe

        # In plain words:
        # module.weight must keep its dtype and shape.
        # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
        # this should be of the same dtype as the original tensor
        #
        # We check that the following invariants hold:
        #    X = module.weight
        #    Y = param.right_inverse(X)
        #    assert isinstance(Y, Tensor) or
        #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
        #    # Consistency checks
        #    assert X.dtype == Z.dtype and X.shape == Z.shape
        #    # If it has one input, this allows us to use set_ to
        #    # move data to/from the original tensor without changing its id (which is what the
        #    # optimizer uses to track parameters)
        #    if isinstance(Y, Tensor)
        #        assert X.dtype == Y.dtype
        # Below we use original = X, new = Y

        original_shape = original.shape
        original_dtype = original.dtype

        # Compute new
        with torch.no_grad():
            new = original
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    try:
                        new = module.right_inverse(new)
                    except NotImplementedError:
                        pass
                # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(
            new, collections.abc.Sequence
        ):
            raise ValueError(
                "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                f"Got {type(new).__name__}"
            )

        # Set the number of original tensors
        self.is_tensor = isinstance(new, Tensor)
        self.ntensors = 1 if self.is_tensor else len(new)

        # Register the tensor(s)
        if self.is_tensor:
            if original.dtype != new.dtype:
                raise ValueError(
                    "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                    f"original.dtype: {original.dtype}\n"
                    f"right_inverse(original).dtype: {new.dtype}"
                )
            # Write `new` into `original` in place so that the user does not need to
            # re-register the parameter manually in the optimizer
            with torch.no_grad():
                _maybe_set(original, new)
            _register_parameter_or_buffer(self, "original", original)
        else:
            for i, originali in enumerate(new):
                if not isinstance(originali, Tensor):
                    raise ValueError(
                        "'right_inverse' must return a Tensor or a Sequence of tensors "
                        "(list, tuple...). "
                        f"Got element {i} of the sequence with type {type(originali).__name__}."
                    )

                # If the original tensor was a Parameter that required grad, we expect the user to
                # add the new parameters to the optimizer after registering the parametrization
                # (this is documented)
                if isinstance(original, Parameter):
                    originali = Parameter(originali, original.requires_grad)
                originali.requires_grad_(original.requires_grad)
                _register_parameter_or_buffer(self, f"original{i}", originali)

        if not self.unsafe:
            # Consistency checks:
            # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
            # Z = forward(right_inverse(original))
            Z = self()
            if not isinstance(Z, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(Z).__name__}."
                )
            if Z.dtype != original_dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
                    f"unparametrized dtype: {original_dtype}\n"
                    f"parametrized dtype: {Z.dtype}"
                )
            if Z.shape != original_shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
                    f"unparametrized shape: {original_shape}\n"
                    f"parametrized shape: {Z.shape}"
                )

    def right_inverse(self, value: Tensor) -> None:
        r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order.

        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Args:
            value (Tensor): Value with which to initialize the tensor
        """
        # All the exceptions in this function should almost never throw.
        # They could throw if, for example, the right_inverse function returns a different
        # dtype when given a different input, which should most likely be caused by a
        # bug in the user's code
        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(
                        f"parametrization {type(module).__name__} does not implement "
                        "right_inverse."
                    )
            if self.is_tensor:
                # These exceptions should only throw when a right_inverse function does not
                # return the same dtype for every input, which should most likely be caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                _maybe_set(self.original, value)
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    _maybe_set(original_i, tensor)

    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x


def _inject_new_class(module: Module) -> None:
    r"""Set up a module to be parametrized.

    This works by substituting the class of the module with a class
    that extends it, so that a property can be injected.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls


def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out.

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(parametrization) -> Tensor:
        global _cache
        key = (id(module), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            if torch.jit.is_scripting():
                # Scripting
                raise RuntimeError(
                    "Caching is not implemented for scripting. "
                    "Either disable caching or avoid scripting."
                )
            elif torch._C._get_tracing_state() is not None:
                # Tracing
                raise RuntimeError(
                    "Cannot trace a model while caching parametrizations."
                )
            else:
                return get_cached_parametrization(parametrization)
        else:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()

    def set_original(self, value: Tensor) -> None:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))


def register_parametrization(
    module: Module,
    tensor_name: str,
    parametrization: Module,
    *,
    unsafe: bool = False,
) -> Module:
    r"""Register a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module.

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented by returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``, ...

    .. note::

        If unsafe=False (default) both the forward and right_inverse methods will be called
        once to perform a number of consistency checks.
        If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register
    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T  # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank 1 matrix multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank 1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1

    """
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )
            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                        )
                    if Z.dtype != Y.dtype:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same dtype "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
                            f"returned dtype: {Z.dtype}"
                        )
                    if Z.shape != Y.shape:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same shape "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.shape: {Y.shape}\n"
                            f"returned shape: {Z.shape}"
                        )
            # else right_inverse is assumed to be the identity

        # add the new parametrization to the parametrization list
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name].append(parametrization)
        # If unsafe was True in previous parametrization, keep it enabled
        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # We create this early to check for possible errors
        parametrizations = ParametrizationList(
            [parametrization], original, unsafe=unsafe
        )
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name] = parametrizations
    else:
        raise ValueError(
            f"Module '{module}' does not have a parameter, a buffer, or a "
            f"parametrized element with name '{tensor_name}'"
        )
    return module


def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Determine if a module has a parametrization.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): name of the parameter in the module
            Default: ``None``
    Returns:
        ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
        or if it has any parametrization when :attr:`tensor_name` is ``None``;
        otherwise ``False``
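
    Example:
        A minimal illustrative sketch; ``nn.Identity`` stands in for a real parametrization:

        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = nn.Linear(4, 4)
        >>> P.is_parametrized(m)
        False
        >>> m = P.register_parametrization(m, "weight", nn.Identity())
        >>> P.is_parametrized(m)
        True
        >>> P.is_parametrized(m, "weight"), P.is_parametrized(m, "bias")
        (True, False)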
    """
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is None or not isinstance(parametrizations, ModuleDict):
        return False
    if tensor_name is None:
        # Check that there is at least one parametrized buffer or Parameter
        return len(parametrizations) > 0
    else:
        return tensor_name in parametrizations


def remove_parametrizations(
    module: Module,
    tensor_name: str,
    leave_parametrized: bool = True,
) -> Module:
    r"""Remove the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which to remove the parametrizations
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
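
    Example:
        An illustrative sketch; ``Double`` is a toy parametrization defined only for this example:

        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Double(nn.Module):
        >>>     def forward(self, X):
        >>>         return 2.0 * X
        >>>
        >>> m = P.register_parametrization(nn.Linear(3, 3), "weight", Double())
        >>> w = m.weight.detach().clone()  # current parametrized value
        >>> m = P.remove_parametrizations(m, "weight", leave_parametrized=True)
        >>> P.is_parametrized(m, "weight")
        False
        >>> torch.allclose(m.weight, w)  # the parametrized value was kept
        True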
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(
            f"Module {module} does not have a parametrization on {tensor_name}"
        )

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not change its id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                if type(original) is torch.Tensor:
                    _maybe_set(original, t)
                else:
                    try:
                        _maybe_set(original, t)
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        raise RuntimeError(
                            "Calling remove_parametrizations() with leave_parametrized=True "
                            "for a parameter that is an instance of a tensor subclass requires "
                            "set_() to be implemented correctly for the tensor subclass. "
                            "Alternatively, one can opt into the swap_tensors path. "
                            "Either set leave_parametrized=False, provide a working implementation "
                            "for set_() in the tensor subclass, or set "
                            "torch.__future__.set_swap_module_params_on_conversion(True)."
                        ) from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError(
                "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                "that is parametrized in terms of a sequence of tensors."
            )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module


def type_before_parametrizations(module: Module) -> type:
    r"""Return the module type before parametrizations were applied, or the module's type if it is not parametrized.

    Args:
        module (nn.Module): module to get type of
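
    Example:
        An illustrative sketch; ``nn.Identity`` stands in for a real parametrization:

        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = P.register_parametrization(nn.Linear(2, 2), "weight", nn.Identity())
        >>> type(m).__name__
        'ParametrizedLinear'
        >>> P.type_before_parametrizations(m) is nn.Linear
        True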
    """
    if is_parametrized(module):
        return module.__class__.__bases__[0]
    else:
        return type(module)


def transfer_parametrizations_and_params(
    from_module: Module,
    to_module: Module,
    tensor_name: Optional[str] = None,
) -> Module:
    r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`.

    If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise
    transfers all parametrized parameters. If those parameters do not exist in :attr:`to_module`,
    they will be created. Does nothing if :attr:`from_module` is not parametrized.

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
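
    Example:
        An illustrative sketch; ``nn.Identity`` stands in for a real parametrization:

        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> src = P.register_parametrization(nn.Linear(3, 3), "weight", nn.Identity())
        >>> dst = P.transfer_parametrizations_and_params(src, nn.Linear(3, 3))
        >>> P.is_parametrized(dst, "weight")
        True
        >>> torch.allclose(src.weight, dst.weight)
        True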
    """
    if is_parametrized(from_module):
        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

        # get list of all params or the single param to transfer
        parameters_to_transfer: Union[list, ModuleDict] = (
            from_module.parametrizations if tensor_name is None else [tensor_name]
        )

        assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
        for parameter_name in parameters_to_transfer:
            # initialize the to-be-transferred param in to_module if it doesn't exist already
            if not hasattr(to_module, parameter_name):
                setattr(
                    to_module,
                    parameter_name,
                    Parameter(getattr(from_module, parameter_name)),
                )

            # apply the param's parametrizations to to_module
            for param_func in from_module.parametrizations[parameter_name]:
                register_parametrization(to_module, parameter_name, param_func)
            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

            # make values match, original values can be stored in either original or
            # original0, original1..., need to check both cases
            if hasattr(from_module.parametrizations[parameter_name], "original"):
                to_module.parametrizations[
                    parameter_name
                ].original = from_module.parametrizations[parameter_name].original
            else:
                num = 0
                orig_num = "original" + str(num)
                # loop through each original# until all values have been set
                while hasattr(from_module.parametrizations[parameter_name], orig_num):
                    setattr(
                        to_module.parametrizations[parameter_name],
                        orig_num,
                        getattr(from_module.parametrizations[parameter_name], orig_num),
                    )
                    num = num + 1
                    orig_num = "original" + str(num)

    return to_module