# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""
from typing import Any, Optional, TypeVar

import torch
import torch.nn.functional as F
from torch.nn.modules import Module


__all__ = [
    "SpectralNorm",
    "SpectralNormLoadStateDictPreHook",
    "SpectralNormStateDictHook",
    "spectral_norm",
    "remove_spectral_norm",
]


class SpectralNorm:
    # Invariant before and after each forward call:
    #   u = F.normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version: int = 1
    # At version 1:
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = W_orig / (u @ W_orig @ v)` rather than the stored `W`.
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(
        self,
        name: str = "weight",
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
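        # For example, with the default `dim=0` a Conv2d weight of shape
        # (out_channels, in_channels, kH, kW) is flattened to a matrix of shape
        # (out_channels, in_channels * kH * kW), whose largest singular value is
        # then estimated by power iteration:
        #   SpectralNorm().reshape_weight_to_matrix(torch.randn(16, 3, 5, 5)).shape
        #   # -> torch.Size([16, 75])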
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
            )
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in the `out=` kwarg has the correct shape,
        #          it will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common pattern
        #     in GAN training: loss = D(real) - D(fake). Otherwise, the autograd
        #     engine will complain that variables needed to do backward for the
        #     first forward (i.e., the `u` and `v` vectors) are changed in the
        #     second forward.
        weight = getattr(module, self.name + "_orig")
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of the weight equals `u^T W v`, where
                    # `u` and `v` are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = F.normalize(
                        torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
                    )
                    u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

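        # Note: with `u` and `v` fully converged, `sigma` below equals the
        # largest singular value of `weight_mat`, i.e. roughly
        # `torch.linalg.svdvals(weight_mat)[0]`; with the default single power
        # iteration per forward pass it is only a running approximation of it.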
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + "_u")
        delattr(module, self.name + "_v")
        delattr(module, self.name + "_orig")
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(
            module,
            self.name,
            self.compute_weight(module, do_power_iteration=module.training),
        )

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
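        # Concretely, the multi_dot below computes the least-squares solution
        # x = (W^T W)^+ W^T u of `W @ x = u`, and the trailing `mul_` rescales
        # it so that `u @ W @ v == target_sigma`.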
        v = torch.linalg.multi_dot(
            [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]
        ).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(
        module: Module, name: str, n_power_iterations: int, dim: int, eps: float
    ) -> "SpectralNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two spectral_norm hooks on the same parameter {name}"
                )

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(
                f"`SpectralNorm` cannot be applied as parameter `{name}` is None"
            )
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                "The module passed to `SpectralNorm` can't have uninitialized parameters. "
                "Make sure to run the dummy forward before applying spectral normalization"
            )

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, as it could be an nn.Parameter
        # and would then be added as a parameter. Instead, we register
        # weight.data as a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner classes or
# instancemethods.
class SpectralNormLoadStateDictPreHook:
    # See the comment on SpectralNorm._version for the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For a state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = F.normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W_orig / W).
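    # Concretely: since the old `W` equals `W_orig / sigma` elementwise, the
    # hook below recovers `sigma` as `(weight_orig / weight).mean()` and then
    # reconstructs `v` via `_solve_v_and_rescale`.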
    def __call__(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        fn = self.fn
        version = local_metadata.get("spectral_norm", {}).get(
            fn.name + ".version", None
        )
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if (
                version is None
                and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v"))
                and weight_key not in state_dict
            ):
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ("_orig", "", "_u"):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + "_orig"]
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + "_u"]
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + "_v"] = v


# This is a top level class because Py2 pickle doesn't like inner classes or
# instancemethods.
class SpectralNormStateDictHook:
    # See the comment on SpectralNorm._version for the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if "spectral_norm" not in local_metadata:
            local_metadata["spectral_norm"] = {}
        key = self.fn.name + ".version"
        if key in local_metadata["spectral_norm"]:
            raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
        local_metadata["spectral_norm"][key] = self.fn._version


T_module = TypeVar("T_module", bound=Module)


def spectral_norm(
    module: T_module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> T_module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated using
    the power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in the power iteration method to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.
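
        A minimal sketch of the corresponding call with the newer API (same
        default arguments)::

            torch.nn.utils.parametrizations.spectral_norm(nn.Linear(20, 40))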

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])
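        >>> # A further illustration: ConvTranspose modules default to ``dim=1``,
        >>> # so ``weight_u`` has ``out_channels`` entries.
        >>> m = spectral_norm(nn.ConvTranspose2d(16, 8, 3))
        >>> m.weight_u.size()
        torch.Size([8])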

    """
    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError(f"spectral_norm of '{name}' not found in {module}")

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module
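

# A minimal end-to-end sketch (illustrative only; the layer and tensor shapes
# are assumptions for the example):
#
#   layer = spectral_norm(torch.nn.Linear(20, 40))
#   out = layer(torch.randn(2, 20))        # pre-hook runs one power iteration
#   layer = remove_spectral_norm(layer)    # `weight` becomes a plain Parameter again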
367