# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""
from typing import Any, Optional, TypeVar

import torch
import torch.nn.functional as F
from torch.nn.modules import Module


__all__ = [
    "SpectralNorm",
    "SpectralNormLoadStateDictPreHook",
    "SpectralNormStateDictHook",
    "spectral_norm",
    "remove_spectral_norm",
]


class SpectralNorm:
    # Invariant before and after each forward call:
    #   u = F.normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version: int = 1
    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(
        self,
        name: str = "weight",
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
            )
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #     1. `DataParallel` doesn't clone storage if the broadcast tensor
        #        is already on the correct device; and it makes sure that the
        #        parallelized module is already on `device[0]`.
        #     2. If the out tensor in the `out=` kwarg has the correct shape, it
        #        will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common pattern
        #     in GAN training: loss = D(real) - D(fake). Otherwise, the autograd
        #     engine will complain that variables needed to do backward for the
        #     first forward (i.e., the `u` and `v` vectors) are changed in the
        #     second forward.
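        # Hedged illustration (not part of the algorithm; the tensor sizes and
        # iteration count below are made up for exposition): power iteration drives
        # `u` and `v` toward the leading singular vectors, so `u^T W v` approaches
        # the true spectral norm, e.g.
        #     W = torch.randn(40, 20)
        #     u = F.normalize(torch.randn(40), dim=0)
        #     for _ in range(50):
        #         v = F.normalize(W.t() @ u, dim=0)
        #         u = F.normalize(W @ v, dim=0)
        #     torch.dot(u, W @ v)  # ~= torch.linalg.matrix_norm(W, ord=2)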
        weight = getattr(module, self.name + "_orig")
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = F.normalize(
                        torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
                    )
                    u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + "_u")
        delattr(module, self.name + "_v")
        delattr(module, self.name + "_orig")
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(
            module,
            self.name,
            self.compute_weight(module, do_power_iteration=module.training),
        )

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        v = torch.linalg.multi_dot(
            [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]
        ).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(
        module: Module, name: str, n_power_iterations: int, dim: int, eps: float
    ) -> "SpectralNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two spectral_norm hooks on the same parameter {name}"
                )

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(
                f"`SpectralNorm` cannot be applied as parameter `{name}` is None"
            )
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                "The module passed to `SpectralNorm` can't have uninitialized parameters. "
                "Make sure to run a dummy forward before applying spectral normalization"
            )

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

        h, w = weight_mat.size()
        # randomly initialize `u` and `v`
        u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
        v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
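        # Illustrative summary of the resulting layout (assuming name == "weight"):
        #     module.weight_orig          -- nn.Parameter, the trainable unnormalized weight
        #     module.weight_u / .weight_v -- buffers holding the power-iteration vectors
        #     module.weight               -- plain tensor attribute, recomputed by the
        #                                    forward pre-hook before every forward call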
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = F.normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        fn = self.fn
        version = local_metadata.get("spectral_norm", {}).get(
            fn.name + ".version", None
        )
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if (
                version is None
                and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v"))
                and weight_key not in state_dict
            ):
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ("_orig", "", "_u"):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + "_orig"]
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + "_u"]
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + "_v"] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if "spectral_norm" not in local_metadata:
            local_metadata["spectral_norm"] = {}
        key = self.fn.name + ".version"
        if key in local_metadata["spectral_norm"]:
            raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
        local_metadata["spectral_norm"][key] = self.fn._version


T_module = TypeVar("T_module", bound=Module)


def spectral_norm(
    module: T_module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> T_module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated using
    the power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in the power iteration method to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError(f"spectral_norm of '{name}' not found in {module}")

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module
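

# A minimal, self-contained sanity check (an illustrative sketch, not part of the
# library API; the layer sizes and iteration counts below are arbitrary): apply
# spectral_norm to a Linear layer, run a few forward passes so the power iteration
# converges, confirm the rescaled weight has spectral norm close to 1, then remove
# the reparameterization. It only runs when this file is executed directly.
if __name__ == "__main__":
    import torch.nn as nn

    m = spectral_norm(nn.Linear(20, 40), n_power_iterations=5)
    x = torch.randn(8, 20)
    for _ in range(10):
        m(x)  # each forward runs power iteration and recomputes `m.weight`
    # The spectral norm of the effective weight should be approximately 1.
    print(torch.linalg.matrix_norm(m.weight.detach(), ord=2))
    remove_spectral_norm(m)
    print(hasattr(m, "weight_orig"))  # False: the reparameterization is gone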