# mypy: allow-untyped-defs
from enum import auto, Enum
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.utils import parametrize


__all__ = ["orthogonal", "spectral_norm", "weight_norm"]


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # A reasonable default eps, but not too large
    if eps is None:
        eps = 10.0 * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q

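# A minimal sketch (illustrative only, not used by the library) of the two helpers above,
# assuming a tall real matrix:
#
#   A = torch.randn(5, 3)
#   Q = _make_orthogonal(A)
#   assert _is_orthogonal(Q)  # Q has orthonormal columns: Q.mH @ Q is close to the identity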

class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(
        self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True
    ) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
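        # (Illustrative count for n = 3: the strictly lower-triangular reflectors hold
        # 3 complex entries, i.e. 6 real parameters, while the unitary group U(3) is
        # a 9-dimensional manifold.)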
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError(
                "The householder parametrization does not support complex tensors."
            )

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n >= k and X is a tall matrix
        if (
            self.orthogonal_map == _OrthMaps.matrix_exp
            or self.orthogonal_map == _OrthMaps.cayley
        ):
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat(
                    [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1
                )
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(
                    torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5)
                )
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2.0 / (1.0 + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(
                f"Expected a matrix or batch of matrices of shape {self.shape}. "
                f"Got a tensor of shape {Q.shape}."
            )

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We make sure to always copy Q in every path
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # If we do not have use_trivialization=True, we just implement the inverse of the forward
            # map for the Householder. To see why, think that for the Cayley map,
            # we would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor. It is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if (
                self.orthogonal_map == _OrthMaps.cayley
                or self.orthogonal_map == _OrthMaps.matrix_exp
            ):
                raise NotImplementedError(
                    "It is not possible to assign to the matrix exponential "
                    "or the Cayley parametrizations when use_trivialization=False."
                )

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder and complex matrices.
            # See note [Householder complex]
            A, tau = torch.geqrf(Q)
            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
            # The diagonal of A is the diagonal of R from the QR decomposition
            A.diagonal(dim1=-2, dim2=-1).sign_()
            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
            # to use a particular reflection
            A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1
            return A.mT if transpose else A
        else:
            if n == k:
                # We check whether Q is orthogonal
                if not _is_orthogonal(Q):
                    Q = _make_orthogonal(Q)
                else:  # Is orthogonal
                    Q = Q.clone()
            else:
                # Complete Q into a full n x n orthogonal matrix
                N = torch.randn(
                    *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device
                )
                Q = torch.cat([Q, N], dim=-1)
                Q = _make_orthogonal(Q)
            self.base = Q

            # It is necessary to return the -Id, as we use the diagonal for the
            # Householder parametrization. Using -Id makes:
            # householder(torch.zeros(m,n)) == torch.eye(m,n)
            # Poor man's version of eye_like
            neg_Id = torch.zeros_like(Q_init)
            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0)
            return neg_Id


def orthogonal(
    module: Module,
    name: str = "weight",
    orthogonal_map: Optional[str] = None,
    *,
    use_trivialization: bool = True,
) -> Module:
    r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized in terms of the original tensor via three different ``orthogonal_map`` options:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    The same happens when it is not parametrized and ``orthogonal_map="householder"``, even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.


    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
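        >>> # An illustrative sketch: the map can also be chosen explicitly, e.g. the Cayley
        >>> # map; with the default use_trivialization=True a square ``base`` buffer is stored.
        >>> orth_cayley = orthogonal(nn.Linear(30, 30), orthogonal_map="cayley")
        >>> orth_cayley.parametrizations.weight[0].base.shape
        torch.Size([30, 30])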
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError(
            "Expected a matrix or batch of matrices. "
            f"Got a tensor of {weight.ndim} dimensions."
        )

    if orthogonal_map is None:
        orthogonal_map = (
            "matrix_exp"
            if weight.size(-2) == weight.size(-1) or weight.is_complex()
            else "householder"
        )

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError(
            'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
            f"Got: {orthogonal_map}"
        )
    orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module


class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v

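# A minimal sketch (illustrative only) of the decomposition performed by _WeightNorm,
# assuming dim=0 and a (40, 20) weight:
#
#   wn = _WeightNorm(dim=0)
#   g, v = wn.right_inverse(torch.randn(40, 20))  # g: per-row norms of shape (40, 1); v: the weight itself
#   w = wn(g, v)                                  # reconstructs the weight as w = g * v / ||v||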

def weight_norm(module: Module, name: str = "weight", dim: Optional[int] = 0):
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm parametrization

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

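    The parametrization can later be removed, keeping the current value of the weight, with
    :func:`torch.nn.utils.parametrize.remove_parametrizations` (a short sketch, continuing the
    example above)::

        >>> from torch.nn.utils import parametrize
        >>> m = parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
        >>> m.weight.size()
        torch.Size([40, 20])
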
    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

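    # Backwards-compatibility hook: remap state_dicts saved by the old
    # ``torch.nn.utils.weight_norm`` implementation, which stored the decomposition
    # under ``{name}_g`` / ``{name}_v``, to the keys used by this parametrization,
    # ``parametrizations.{name}.original0`` / ``original1``.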
    def _weight_norm_compat_hook(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)
            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module


class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[-{ndim}, {ndim - 1}] but got {dim})"
            )

        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(
                self.dim, *(d for d in range(weight.dim()) if d != self.dim)
            )

        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backproping through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(
                torch.mv(weight_mat, self._v),  # type: ignore[has-type]
                dim=0,
                eps=self.eps,
                out=self._u,  # type: ignore[has-type]
            )
            self._v = F.normalize(
                torch.mv(weight_mat.H, self._u),  # type: ignore[has-type]
                dim=0,
                eps=self.eps,
                out=self._v,  # type: ignore[has-type]
            )

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.vdot(u, torch.mv(weight_mat, v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value


def spectral_norm(
    module: Module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> Module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the weight tensor has
    more than 2 dimensions, it is reshaped to a 2D matrix for the power iteration
    that computes the spectral norm.


    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated with the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module in `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
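        >>> # An illustrative sketch: for a convolution with dim=0 (the default), the weight
        >>> # is flattened to a matrix of shape (out_channels, in_channels * kH * kW) before
        >>> # the power iteration.
        >>> snm_conv = spectral_norm(nn.Conv2d(3, 16, 3), n_power_iterations=5)
        >>> snm_conv.parametrizations.weight[0].n_power_iterations
        5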
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(
        module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)
    )
    return module