# mypy: allow-untyped-defs
from enum import auto, Enum
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.utils import parametrize


__all__ = ["orthogonal", "spectral_norm", "weight_norm"]


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    if eps is None:
        # A reasonable default eps, but not too large
        eps = 10.0 * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real), so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(
        self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True
    ) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
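        # (As a concrete instance of the count above: for n = 3, the reflectors store
        #  3 complex entries below the diagonal, i.e. n(n-1) = 6 real parameters, while
        #  the unitary group U(3) has dimension n^2 = 9, so `tau` cannot be recovered
        #  from the reflectors alone.)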
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError(
                "The householder parametrization does not support complex tensors."
            )

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n >= k and X is a tall matrix
        if (
            self.orthogonal_map == _OrthMaps.matrix_exp
            or self.orthogonal_map == _OrthMaps.cayley
        ):
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat(
                    [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1
                )
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(
                    torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5)
                )
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2.0 / (1.0 + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X, hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(
                f"Expected a matrix or batch of matrices of shape {self.shape}. "
                f"Got a tensor of shape {Q.shape}."
            )

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We make sure to copy Q in every path
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # If we do not have use_trivialization=True, we just implement the inverse of the forward
            # map for the Householder. To see why, note that for the Cayley map,
            # we would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor. It is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if (
                self.orthogonal_map == _OrthMaps.cayley
                or self.orthogonal_map == _OrthMaps.matrix_exp
            ):
                raise NotImplementedError(
                    "It is not possible to assign to the matrix exponential "
                    "or the Cayley parametrizations when use_trivialization=False."
                )

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder and complex matrices.
            # See note [Householder complex]
            A, tau = torch.geqrf(Q)
            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
            # The diagonal of A holds the diagonal of R from the QR decomposition
            A.diagonal(dim1=-2, dim2=-1).sign_()
            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
            # to use a particular reflection
            A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1
            return A.mT if transpose else A
        else:
            if n == k:
                # We check whether Q is orthogonal
                if not _is_orthogonal(Q):
                    Q = _make_orthogonal(Q)
                else:  # Is orthogonal
                    Q = Q.clone()
            else:
                # Complete Q into a full n x n orthogonal matrix
                N = torch.randn(
                    *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device
                )
                Q = torch.cat([Q, N], dim=-1)
                Q = _make_orthogonal(Q)
            self.base = Q

            # It is necessary to return the -Id, as we use the diagonal for the
            # Householder parametrization. Using -Id makes:
            # householder(torch.zeros(m,n)) == torch.eye(m,n)
            # Poor man's version of eye_like
            neg_Id = torch.zeros_like(Q_init)
            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0)
            return neg_Id


def orthogonal(
    module: Module,
    name: str = "weight",
    orthogonal_map: Optional[str] = None,
    *,
    use_trivialization: bool = True,
) -> Module:
    r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
        If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
        of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
        and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
        The same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
        Otherwise, the initial value is the result of the composition of all the registered
        parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.


    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere,
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError(
            "Expected a matrix or batch of matrices. "
            f"Got a tensor of {weight.ndim} dimensions."
        )

    if orthogonal_map is None:
        orthogonal_map = (
            "matrix_exp"
            if weight.size(-2) == weight.size(-1) or weight.is_complex()
            else "householder"
        )

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError(
            'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
            f"Got: {orthogonal_map}"
        )
    orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module
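

# A small usage sketch for the wide case (m < n), complementing the docstring example
# above. The printed distance is illustrative, not a guaranteed value:
#
#     wide = torch.nn.Linear(40, 20)            # weight has shape (20, 40), so m < n
#     orthogonal(wide, orthogonal_map="cayley")
#     Q = wide.weight
#     torch.dist(Q @ Q.mH, torch.eye(20))       # rows are orthonormal, so this is ~0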


class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v


def weight_norm(module: Module, name: str = "weight", dim: int = 0):
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
        \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm parametrization registered on the
        specified weight

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)
            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module
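

# A rough sanity-check sketch of the decomposition above (an illustration, not part of the API):
# for the default ``dim=0`` on a Linear layer, ``original0`` holds the per-row magnitude g,
# ``original1`` holds the direction v, and the recomposed weight is g * v / ||v|| per row:
#
#     m = weight_norm(torch.nn.Linear(20, 40))
#     g = m.parametrizations.weight.original0      # shape (40, 1)
#     v = m.parametrizations.weight.original1      # shape (40, 20)
#     torch.allclose(m.weight, g * v / v.norm(dim=1, keepdim=True))  # True up to rounding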


class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[-{ndim}, {ndim - 1}] but got {dim})"
            )

        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(
                self.dim, *(d for d in range(weight.dim()) if d != self.dim)
            )

        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on the correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in the `out=` kwarg has the correct shape, it will
        #    just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, the engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # The spectral norm of the weight is equal to `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
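            # In matrix notation, each pass performs one power-method step,
            #   u <- normalize(W v),  v <- normalize(W^H u),
            # driving u and v towards the leading singular vectors of W.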
            self._u = F.normalize(
                torch.mv(weight_mat, self._v),  # type: ignore[has-type]
                dim=0,
                eps=self.eps,
                out=self._u,  # type: ignore[has-type]
            )
            self._v = F.normalize(
                torch.mv(weight_mat.H, self._u),  # type: ignore[has-type]
                dim=0,
                eps=self.eps,
                out=self._v,  # type: ignore[has-type]
            )

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.vdot(u, torch.mv(weight_mat, v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value


def spectral_norm(
    module: Module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> Module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the weight tensor has
    more than two dimensions, it is reshaped to 2D for the power iteration.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated to the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module in ``training`` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(
        module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)
    )
    return module