xref: /aosp_15_r20/external/pytorch/torch/nn/utils/parametrize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import collections
import copyreg
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn.modules.container import Module, ModuleDict, ModuleList
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


__all__ = [
    "cached",
    "ParametrizationList",
    "register_parametrization",
    "is_parametrized",
    "remove_parametrizations",
    "type_before_parametrizations",
    "transfer_parametrizations_and_params",
]

_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}


@contextmanager
def cached():
    r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`.

    The value of the parametrized objects is computed and cached the first time
    they are required when this context manager is active. The cached values are
    discarded when leaving the context manager.

    This is useful when using a parametrized parameter more than once in the forward pass.
    An example of this is when parametrizing the recurrent kernel of an RNN or when
    sharing weights.

    The simplest way to activate the cache is by wrapping the forward pass of the neural network

    .. code-block:: python

        import torch.nn.utils.parametrize as P
        ...
        with P.cached():
            output = model(inputs)

    in training and evaluation. One may also wrap just the parts of the module that use
    the parametrized tensors several times. For example, the loop of an RNN with a
    parametrized recurrent kernel:

    .. code-block:: python

        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
    """
    global _cache
    global _cache_enabled
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        if not _cache_enabled:
            _cache = {}


def _register_parameter_or_buffer(module, name, X):
    if isinstance(X, Parameter):
        module.register_parameter(name, X)
    else:
        module.register_buffer(name, X)


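# _maybe_set writes ``src`` into ``dest`` in place. When the torch.__future__
# swap_module_params_on_conversion flag is enabled, or when ``dest`` is a
# traceable wrapper subclass, the two tensors are exchanged via
# torch.utils.swap_tensors (preserving subclass metadata); otherwise the data
# is moved with Tensor.set_, which keeps the id() of ``dest`` unchanged.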
def _maybe_set(dest: Tensor, src: Tensor) -> None:
    should_swap = (
        get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
    )
    if should_swap:
        if isinstance(dest, Parameter) and not isinstance(src, Parameter):
            src = Parameter(src, requires_grad=dest.requires_grad)
        torch.utils.swap_tensors(dest, src)
    else:
        dest.set_(src)  # type: ignore[call-overload]


class ParametrizationList(ModuleList):
    r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.

    It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
    has been parametrized with :func:`register_parametrization`.

    If the first registered parametrization has a ``right_inverse`` that returns one tensor or
    does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
    it will hold the tensor under the name ``original``.
    If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
    ``original0``, ``original1``, ...

    .. warning::
        This class is used internally by :func:`register_parametrization`. It is documented
        here for completeness. It shall not be instantiated by the user.

    Args:
        modules (sequence): sequence of modules representing the parametrizations
        original (Parameter or Tensor): parameter or buffer that is parametrized
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.
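
    Example of how this container is exposed on a parametrized module (illustrative
    sketch; assumes ``import torch.nn as nn``, ``import torch.nn.utils.parametrize as P``
    and a hypothetical ``Double`` parametrization whose ``forward`` multiplies the
    tensor by two)::

        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", Double())
        >>> print(type(m.parametrizations["weight"]).__name__)
        ParametrizationList
        >>> print(m.parametrizations["weight"].original.shape)
        torch.Size([5, 5])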
    """

    original: Tensor
    unsafe: bool

    def __init__(
        self,
        modules: Sequence[Module],
        original: Union[Tensor, Parameter],
        unsafe: bool = False,
    ) -> None:
        # We require this because we need to treat the first parametrization differently
        # This should never throw, unless this class is used from the outside
        if len(modules) == 0:
            raise ValueError("ParametrizationList requires one or more modules.")

        super().__init__(modules)
        self.unsafe = unsafe

        # In plain words:
        # module.weight must keep its dtype and shape.
        # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
        # the result should be of the same dtype as the original tensor
        #
        # We check that the following invariants hold:
        #    X = module.weight
        #    Y = param.right_inverse(X)
        #    assert isinstance(Y, Tensor) or
        #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
        #    # Consistency checks
        #    assert X.dtype == Z.dtype and X.shape == Z.shape
        #    # If right_inverse returns one tensor, this allows us to use set_ to
        #    # move data to/from the original tensor without changing its id (which is what the
        #    # optimizer uses to track parameters)
        #    if isinstance(Y, Tensor):
        #      assert X.dtype == Y.dtype
        # Below we use original = X, new = Y

        original_shape = original.shape
        original_dtype = original.dtype

        # Compute new
        with torch.no_grad():
            new = original
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    try:
                        new = module.right_inverse(new)
                    except NotImplementedError:
                        pass
                # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(
            new, collections.abc.Sequence
        ):
            raise ValueError(
                "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                f"Got {type(new).__name__}"
            )

        # Set the number of original tensors
        self.is_tensor = isinstance(new, Tensor)
        self.ntensors = 1 if self.is_tensor else len(new)

        # Register the tensor(s)
        if self.is_tensor:
            if original.dtype != new.dtype:
                raise ValueError(
                    "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                    f"original.dtype: {original.dtype}\n"
                    f"right_inverse(original).dtype: {new.dtype}"
                )
            # Write the new value back into the original tensor so that the user does not need
            # to re-register the parameter manually in the optimizer
            with torch.no_grad():
                _maybe_set(original, new)
            _register_parameter_or_buffer(self, "original", original)
        else:
            for i, originali in enumerate(new):
                if not isinstance(originali, Tensor):
                    raise ValueError(
                        "'right_inverse' must return a Tensor or a Sequence of tensors "
                        "(list, tuple...). "
                        f"Got element {i} of the sequence with type {type(originali).__name__}."
                    )

                # If the original tensor was a Parameter that required grad, we expect the user to
                # add the new parameters to the optimizer after registering the parametrization
                # (this is documented)
                if isinstance(original, Parameter):
                    originali = Parameter(originali, original.requires_grad)
                originali.requires_grad_(original.requires_grad)
                _register_parameter_or_buffer(self, f"original{i}", originali)

        if not self.unsafe:
            # Consistency checks:
            # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
            # Z = forward(right_inverse(original))
            Z = self()
            if not isinstance(Z, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(Z).__name__}."
                )
            if Z.dtype != original_dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"unparametrized dtype: {original_dtype}\n"
                    f"parametrized dtype: {Z.dtype}"
                )
            if Z.shape != original_shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"unparametrized shape: {original_shape}\n"
                    f"parametrized shape: {Z.shape}"
                )

    def right_inverse(self, value: Tensor) -> None:
        r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order.

        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Args:
            value (Tensor): Value with which to initialize the tensor
        """
        # The exceptions in this function should almost never be raised.
        # They could be raised if, for example, a right_inverse function returns a different
        # dtype when given a different input, which is most likely caused by a
        # bug in the user's code

        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(
                        f"parametrization {type(module).__name__} does not implement "
                        "right_inverse."
                    )
            if self.is_tensor:
                # These exceptions should only be raised when a right_inverse function does not
                # return the same dtype for every input, which is most likely caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                _maybe_set(self.original, value)
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    _maybe_set(original_i, tensor)

    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization does not work with scripting.")
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x


def _inject_new_class(module: Module) -> None:
    r"""Set up a module to be parametrized.

    This works by substituting the class of the module with a class
    that extends it, so that a property can be injected.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls


def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Inject a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(parametrization) -> Tensor:
        global _cache
        key = (id(module), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization does not work with scripting.")
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            if torch.jit.is_scripting():
                # Scripting
                raise RuntimeError(
                    "Caching is not implemented for scripting. "
                    "Either disable caching or avoid scripting."
                )
            elif torch._C._get_tracing_state() is not None:
                # Tracing
                raise RuntimeError(
                    "Cannot trace a model while caching parametrizations."
                )
            else:
                return get_cached_parametrization(parametrization)
        else:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()

    def set_original(self, value: Tensor) -> None:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization does not work with scripting.")
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))


def register_parametrization(
    module: Module,
    tensor_name: str,
    parametrization: Module,
    *,
    unsafe: bool = False,
) -> Module:
    r"""Register a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module.

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented by returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``, ...

    .. note::

        If ``unsafe=False`` (default), both the ``forward`` and ``right_inverse`` methods will be
        called once to perform a number of consistency checks.
        If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such a parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register
    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T   # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank 1 matrix multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank 1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1

    """
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )
            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                        )
                    if Z.dtype != Y.dtype:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same dtype "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
                            f"returned dtype: {Z.dtype}"
                        )
                    if Z.shape != Y.shape:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same shape "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.shape: {Y.shape}\n"
                            f"returned shape: {Z.shape}"
                        )
            # else right_inverse is assumed to be the identity

        # add the new parametrization to the parametrization list
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name].append(parametrization)
        # If unsafe was True in a previous parametrization, keep it enabled
        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # We create this early to check for possible errors
        parametrizations = ParametrizationList(
            [parametrization], original, unsafe=unsafe
        )
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name] = parametrizations
    else:
        raise ValueError(
            f"Module '{module}' does not have a parameter, a buffer, or a "
            f"parametrized element with name '{tensor_name}'"
        )
    return module


def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Determine if a module has a parametrization.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): name of the parameter in the module
            Default: ``None``
    Returns:
        ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
        or if it has any parametrization when :attr:`tensor_name` is ``None``;
        otherwise ``False``
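
    Example (illustrative sketch; ``Double`` is a hypothetical parametrization
    that multiplies the tensor by two)::

        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Double(nn.Module):
        >>>     def forward(self, X):
        >>>         return 2 * X
        >>>
        >>> m = nn.Linear(5, 5)
        >>> print(P.is_parametrized(m))
        False
        >>> _ = P.register_parametrization(m, "weight", Double())
        >>> print(P.is_parametrized(m), P.is_parametrized(m, "weight"), P.is_parametrized(m, "bias"))
        True True False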
    """
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is None or not isinstance(parametrizations, ModuleDict):
        return False
    if tensor_name is None:
        # Check that there is at least one parametrized buffer or Parameter
        return len(parametrizations) > 0
    else:
        return tensor_name in parametrizations


def remove_parametrizations(
    module: Module,
    tensor_name: str,
    leave_parametrized: bool = True,
) -> Module:
    r"""Remove the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which to remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
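
    Example (illustrative sketch, reusing the imports and the hypothetical ``Double``
    parametrization from the example in :func:`is_parametrized`)::

        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", Double())
        >>> # With leave_parametrized=True the parametrized value becomes the new weight
        >>> _ = P.remove_parametrizations(m, "weight", leave_parametrized=True)
        >>> print(P.is_parametrized(m, "weight"))
        False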
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(
            f"Module {module} does not have a parametrization on {tensor_name}"
        )

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not change its id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                if type(original) is torch.Tensor:
                    _maybe_set(original, t)
                else:
                    try:
                        _maybe_set(original, t)
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        raise RuntimeError(
                            "Calling remove_parametrizations() with leave_parametrized=True "
                            "for a parameter that is an instance of a tensor subclass requires "
                            "set_() to be implemented correctly for the tensor subclass. "
                            "Alternatively, one can opt into the swap_tensors path. "
                            "Either set leave_parametrized=False, provide a working implementation "
                            "for set_() in the tensor subclass, or set "
                            "torch.__future__.set_swap_module_params_on_conversion(True)."
                        ) from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError(
                "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                "that is parametrized in terms of a sequence of tensors."
            )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module


def type_before_parametrizations(module: Module) -> type:
    r"""Return the module type before parametrizations were applied; if the module is not parametrized, return its current type.

    Args:
        module (nn.Module): module to get type of
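
    Example (illustrative sketch, reusing the imports and the hypothetical ``Double``
    parametrization from the example in :func:`is_parametrized`)::

        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", Double())
        >>> print(type(m).__name__)  # the class injected by the parametrization
        ParametrizedLinear
        >>> print(P.type_before_parametrizations(m).__name__)
        Linear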
    """
    if is_parametrized(module):
        return module.__class__.__bases__[0]
    else:
        return type(module)


def transfer_parametrizations_and_params(
    from_module: Module,
    to_module: Module,
    tensor_name: Optional[str] = None,
) -> Module:
    r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`.

    If :attr:`tensor_name` is specified, only transfers the specified parameter; otherwise
    transfers all parametrized parameters. If those parameters do not exist in to_module, they will be created.
    Does nothing if from_module is not parametrized.

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
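
    Example (illustrative sketch, reusing the ``torch``/``nn``/``P`` imports and the
    hypothetical ``Double`` parametrization from the examples above)::

        >>> src, dst = nn.Linear(5, 5), nn.Linear(5, 5)
        >>> _ = P.register_parametrization(src, "weight", Double())
        >>> _ = P.transfer_parametrizations_and_params(src, dst)
        >>> print(torch.allclose(src.weight, dst.weight))
        True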
    """
    if is_parametrized(from_module):
        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

        # get list of all params or the single param to transfer
        parameters_to_transfer: Union[list, ModuleDict] = (
            from_module.parametrizations if tensor_name is None else [tensor_name]
        )

        assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
        for parameter_name in parameters_to_transfer:
            # initialize the to-be-transferred param in to_module if it doesn't exist already
            if not hasattr(to_module, parameter_name):
                setattr(
                    to_module,
                    parameter_name,
                    Parameter(getattr(from_module, parameter_name)),
                )

            # apply the param's parametrizations to to_module
            for param_func in from_module.parametrizations[parameter_name]:
                register_parametrization(to_module, parameter_name, param_func)
            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

            # make the values match; original values can be stored either in original or
            # in original0, original1, ..., so we need to check both cases
            if hasattr(from_module.parametrizations[parameter_name], "original"):
                to_module.parametrizations[
                    parameter_name
                ].original = from_module.parametrizations[parameter_name].original
            else:
                num = 0
                orig_num = "original" + str(num)
                # loop through each original# until all values have been set
                while hasattr(from_module.parametrizations[parameter_name], orig_num):
                    setattr(
                        to_module.parametrizations[parameter_name],
                        orig_num,
                        getattr(from_module.parametrizations[parameter_name], orig_num),
                    )
                    num = num + 1
                    orig_num = "original" + str(num)

    return to_module