# Owner(s): ["module: nn"]
import pickle
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    gradcheck,
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    skipIfNoLapack,
    skipIfTorchDynamo,
    swap,
    TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor


class TestNNParametrization(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    # torch/nn/utils/parametrize
    @skipIfNoLapack
    @swap([True, False])
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer and that removing them restores the initial state.
        It also tests that backpropagating through them works as expected.
        """

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
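            # Note: tril(-1) keeps only the strictly lower triangle, so X - X.T
            # is skew-symmetric by construction.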
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()

        class Resize(nn.Module):
            def forward(self, X):
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Test unsafe flag
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
            model(torch.ones(8, 8))

        # One parametrization with unsafe=True
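        # unsafe=True skips the consistency checks, so a parametrization that
        # changes the shape of the tensor (like Resize) can be registered.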
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X
        # Structure tests
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)
        # Remove
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        # Basic tests
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
        # Should not throw

        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
        # Should not throw
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # The module is not parametrized
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
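        # With leave_parametrized=True the current parametrized value becomes the
        # new unparametrized tensor instead of restoring the original one.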
        for _ in range(2):
            parametrize.register_parametrization(model, "weight", Skew())
            parametrize.register_parametrization(model, "weight", Orthogonal())
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            # We didn't change the dtype, nor did we have multiple inputs, so the id should be the same
            self.assertEqual(id(model.weight), initial_weight_id)
            self.assertEqual(id(model.bias), initial_bias_id)

            # Should not throw. Things are updated
            weight_copy = model.weight.clone()
            bias_copy = model.bias.clone()
            sgd.zero_grad()
            (model.weight.T @ model.bias).sum().backward()
            sgd.step()
            self.assertNotEqual(model.weight, weight_copy)
            self.assertNotEqual(model.bias, bias_copy)
            if get_swap_module_params_on_conversion():
                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                del weight_copy, bias_copy

    @swap([True, False])
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest parametrizations,
        meaning that the original parameter is parametrized again
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A

        # Add nested parametrization
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

    @swap([True, False])
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()

        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
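        # Its right_inverse stores the target matrix in B and returns zero: the
        # Cayley map of zero is the identity, so forward(0) == B recovers the target.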
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)

        N = 5
        model = nn.Linear(N, N)
        # Register the skew-symmetric constraint. The result is now skew-symmetric
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way
        # Now register the Cayley map. The result is orthogonal
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

    @swap([True, False])
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Registering a parametrization on a non-existent parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
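        # A right_inverse that returns a tuple makes this a multi-input
        # parametrization; the pieces are stored as original0, original1, ...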
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
                return z, torch.zeros_like(z)

        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return None

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # Change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of length one, although weird, is correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)

        parametrize.register_parametrization(module, "weight", SequenceLen1())
        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
        _ = module.weight  # Does not throw
        self.assertTrue(parametrize.is_parametrized(module))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # None of the operations above should have altered the weight
        self.assertFalse(parametrize.is_parametrized(module))
        self.assertEqual(module.weight, weight_init)

    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on a parametrized tensor

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())

        # Has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # The following errors are mostly caused by bugs in the code of the parametrization itself

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way.
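                # The best rank-1 approximation is S[0] * U[:, 0] @ Vh[0, :]
                # (Eckart-Young); sqrt(S[0]) is split between the two factors.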
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # Simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)
        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix now is twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model now has three parameters
        self.assertEqual(len(list(model.parameters())), 3)

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        # Test backward. Should not throw
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

        # Same drill as before, removing should work as expected
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # The model now has two parameters
        self.assertEqual(len(list(model.parameters())), 2)

        # Test backward. Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization"""

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Test that the caching system works
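        # Inside parametrize.cached() the parametrization is evaluated once and
        # the result is reused, so repeated accesses return the same object.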
        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching"""

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
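        # transfer_parametrizations_and_params copies both the parametrization
        # modules and the original (unparametrized) tensors onto to_model.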
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

            A = to_model.weight
            B = to_model.weight
            self.assertEqual(id(A), id(B))

            # test that the results are distinct objects for each module
            self.assertNotEqual(id(A), id(X))

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Test that a parametrization adopts the module's training mode when registered"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)
        module.train()
        parametrize.register_parametrization(module, "weight", Identity().eval())
        self.assertTrue(module.parametrizations.weight[0].training)
        self.assertTrue(module.parametrizations.weight[1].training)

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )
        parametrize.register_parametrization(model, "weight", Identity())
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate custom implementation of the deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = (
                m1.parametrizations.bias.original
                if parametrize.is_parametrized(m1, "bias")
                else m1.bias
            )
            b2 = (
                m2.parametrizations.bias.original
                if parametrize.is_parametrized(m2, "bias")
                else m2.bias
            )
            # Weights, biases and attributes should be equal but they must be different objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create deepcopy.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors.
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # testing that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        parametrize.register_parametrization(model, "bias", AddOne())
        parametrize.register_parametrization(model, "bias", Double())
        parametrize.register_parametrization(model, "bias", MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertTrue("bias" not in to_model.parametrizations)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way.
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )

        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original0,
            to_model.parametrizations.weight.original0,
        )
        self.assertEqual(
            model.parametrizations.weight.original1,
            to_model.parametrizations.weight.original1,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # testing that changes to one set of parametrizations do not affect the other
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # also check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original0,
            to_model.parametrizations.test_param.original0,
        )
        self.assertEqual(
            model.parametrizations.test_param.original1,
            to_model.parametrizations.test_param.original1,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_new_spectral_norm(self):
        with set_default_dtype(torch.double):
            input = torch.randn(3, 5)
            m = nn.Linear(5, 7)
            m = torch.nn.utils.parametrizations.spectral_norm(m)
            spectral_norm_m = m.parametrizations.weight[0]
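            # The new spectral_norm is implemented as a parametrization, so the
            # module holding its buffers lives at m.parametrizations.weight[0].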

            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))

            # .parametrizations.weight.original should be trainable
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
            self.assertTrue("original" in m.parametrizations.weight._parameters)

            # u should be just a reused buffer
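            # (_u and _v are the power-iteration estimates of the leading left and
            # right singular vectors used to approximate the spectral norm)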
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
            self.assertTrue("_u" in spectral_norm_m._buffers)
            self.assertTrue("_v" in spectral_norm_m._buffers)

            # weight should be a plain attribute, not counted as a buffer or a param
            self.assertIsNotNone(m.weight)
            self.assertFalse("weight" in m._buffers)
            self.assertFalse("weight" in m._parameters)

1225            # it should also share storage with `parametrizations.weight.original`
1226            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
1227            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
1228            self.assertEqual(
1229                m.parametrizations.weight.original.stride(), m.weight.stride()
1230            )
1231
1232            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1233
1234            # spectral_norm is the only parametrization
1235            self.assertFalse(hasattr(m, "parametrizations"))
1236            self.assertTrue("weight" in m._parameters)
1237
1238            # We can register spectral_norm multiple times on the same parameter
1239            # and on multiple parameters in the same module
1240            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
1241            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
1242            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")
1243
1244            # If we remove the parametrization on bias, weight is still parametrized
1245            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
1246            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
1247            self.assertTrue("bias" in m._parameters)
1248            self.assertTrue(hasattr(m, "parametrizations"))
1249            self.assertFalse("weight" in m._parameters)
1250
1251            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1252            # Neither weight nor bias is parametrized
1253            self.assertFalse(hasattr(m, "parametrizations"))
1254            self.assertTrue("weight" in m._parameters)
1255            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))
1256
1257            # test correctness in training/eval modes and cpu/multi-gpu settings
1258            for apply_dp in (True, False):
1259                if apply_dp:
1260                    if not TEST_MULTIGPU:
1261                        continue
1262                    device = torch.device("cuda:0")
1263
1264                    def maybe_wrap(m):
1265                        return torch.nn.DataParallel(m, [0, 1])
1266
1267                else:
1268                    device = torch.device("cpu")
1269
1270                    def maybe_wrap(m):
1271                        return m
1272
1273                for requires_grad in (True, False):
1274
1275                    def get_modules():
1276                        m = nn.Linear(3, 4).to(device)
1277                        m.weight.requires_grad_(requires_grad)
1278                        m = torch.nn.utils.parametrizations.spectral_norm(m)
1279                        wrapped_m = maybe_wrap(m)
1280                        spectral_norm_m = m.parametrizations.weight[0]
1281                        return m, wrapped_m, spectral_norm_m
1282
1283                    input = torch.randn(2, 3, device=device)
1284
1285                    m, wrapped_m, spectral_norm_m = get_modules()
1286
1287                    self.assertTrue(hasattr(spectral_norm_m, "_u"))
1288                    u0 = spectral_norm_m._u.clone()
1289                    v0 = spectral_norm_m._v.clone()
1290
1291                    # TEST TRAINING BEHAVIOR
1292
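                    # In training mode every forward runs one power-iteration step that updates
                    # _u and _v in place, so repeated forwards change the estimate; in eval mode
                    # the buffers are frozen and the output is deterministic.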
1293                    # We perform GD first to modify the initial matrix
1294                    opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)
1295
1296                    opt.zero_grad()
1297                    wrapped_m(input).sum().backward()
1298                    opt.step()
1299
1300                    out = wrapped_m(input)
1301                    if requires_grad:
1302                        # run forward again and assert that u and v are updated
1303                        self.assertNotEqual(u0, spectral_norm_m._u)
1304                        self.assertNotEqual(v0, spectral_norm_m._v)
1305
1306                    # assert that backprop reaches original weight
1307                    # can't use gradcheck because the function changes as we
1308                    # evaluate it in training mode (each forward updates u and v)
1309                    if requires_grad:
1310                        torch.autograd.grad(
1311                            out.sum(), m.parametrizations.weight.original
1312                        )
1313
1314                    # test that backward works with multiple forwards
1315                    # it uses training mode, so we need to reset the `u` and `v` vectors
1316                    # to the same value at the beginning for the finite difference test to pass
1317                    saved_u = spectral_norm_m._u.clone()
1318                    saved_v = spectral_norm_m._v.clone()
1319
1320                    def fn(input):
1321                        spectral_norm_m._u.data.copy_(saved_u)
1322                        spectral_norm_m._v.data.copy_(saved_v)
1323                        out0 = wrapped_m(input)
1324                        out1 = wrapped_m(input)
1325                        return out0 + out1
1326
1327                    # Make sure we can compute gradients with respect to all the parameters
1328                    # in the case of a double forward
1329                    fn(input.clone().requires_grad_()).sum().backward()
1330                    gradcheck(
1331                        fn, (input.clone().requires_grad_(),), check_batched_grad=False
1332                    )
1333
1334                    # test removing
1335                    # spectral norm module needs to be in eval mode if we'd like to
1336                    # avoid doing another power iteration
1337                    m, wrapped_m, _ = get_modules()
1338                    pre_remove_out = wrapped_m(input)
1339                    if get_swap_module_params_on_conversion():
1340                        # When using the swap_tensors path, this is needed so that the autograd
1341                        # graph is not alive anymore.
1342                        pre_remove_out_ref = pre_remove_out.detach()
1343                        del pre_remove_out
1344                    else:
1345                        pre_remove_out_ref = pre_remove_out
1346                    m.eval()
1347                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1348                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
1349
1350                    torch.nn.utils.parametrizations.spectral_norm(m)
1351                    for _ in range(3):
1352                        pre_remove_out = wrapped_m(input)
1353                    if get_swap_module_params_on_conversion():
1354                        # When using the swap_tensors path, this is needed so that the autograd
1355                        # graph is not alive anymore.
1356                        pre_remove_out_ref = pre_remove_out.detach()
1357                        del pre_remove_out
1358                    else:
1359                        pre_remove_out_ref = pre_remove_out
1360                    m.eval()
1361                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1362                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
1363
1364                    # TEST EVAL BEHAVIOR
1365                    m, wrapped_m, spectral_norm_m = get_modules()
1366                    wrapped_m(input)
1367                    last_train_out = wrapped_m(input)
1368                    last_train_u = spectral_norm_m._u.clone()
1369                    last_train_v = spectral_norm_m._v.clone()
1370                    wrapped_m.zero_grad()
1371                    wrapped_m.eval()
1372
1373                    eval_out0 = wrapped_m(input)
1374                    # assert that eval gives the same result as the last training iteration
1375                    self.assertEqual(eval_out0, last_train_out)
1376                    # assert that doing more iterations in eval doesn't change things
1377                    self.assertEqual(eval_out0, wrapped_m(input))
1378                    self.assertEqual(last_train_u, spectral_norm_m._u)
1379                    self.assertEqual(last_train_v, spectral_norm_m._v)
1380
1381                    # FIXME: the code below is flaky when executed with DataParallel
1382                    # see https://github.com/pytorch/pytorch/issues/13818
1383                    if apply_dp:
1384                        continue
1385
1386                    # test that backward works with multiple forwards in mixed training
1387                    # and eval modes
1388                    # it uses training mode, so we need to reset the `u` and `v` vectors
1389                    # to the same value at the beginning for the finite difference test to pass
1390                    saved_u = spectral_norm_m._u.clone()
1391                    saved_v = spectral_norm_m._v.clone()
1392
1393                    def fn(input):
1394                        spectral_norm_m._u.data.copy_(saved_u)
1395                        spectral_norm_m._v.data.copy_(saved_v)
1396                        wrapped_m.train()
1397                        out0 = wrapped_m(input)
1398                        wrapped_m.eval()
1399                        out1 = wrapped_m(input)
1400                        wrapped_m.train()
1401                        out2 = wrapped_m(input)
1402                        wrapped_m.eval()
1403                        out3 = wrapped_m(input)
1404                        return out0 + out1 + out2 + out3
1405
1406                    gradcheck(fn, (input.clone().requires_grad_(),))
1407
1408                    # assert that backprop reaches the original weight in eval mode
1409                    if requires_grad:
1410
1411                        def fn(weight):
1412                            return wrapped_m(input)
1413
1414                        gradcheck(fn, (m.parametrizations.weight.original,))
1415
1416    def test_register_parametrization_no_grad(self):
1417        r"""Test that it is possible to register a parametrization without gradient"""
1418
1419        class SplitAndCat(nn.Module):
1420            def right_inverse(self, x):
1421                # split the tensor into two halves
1422                return torch.split(x, x.shape[1] // 2)
1423
1424            def forward(self, x0, x1):
1425                return torch.cat([x0, x1])
1426
1427        model = nn.Linear(8, 8)
1428
1429        model.weight.requires_grad = False
1430        parametrize.register_parametrization(model, "weight", SplitAndCat())
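        # Since right_inverse returns a tuple of two tensors, the parametrization stores two
        # originals (original0 and original1); both should keep requires_grad=False below.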
1431        # make sure the parametrized and decomposed tensors both have requires_grad == False
1432        self.assertFalse(model.weight.requires_grad)
1433        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
1434        self.assertFalse(model.parametrizations.weight.original1.requires_grad)
1435
1436    @swap([True, False])
1437    def test_new_spectral_norm_load_state_dict(self):
1438        for activate_times in (0, 3):
1439            inp = torch.randn(2, 3)
1440            m = nn.Linear(3, 5)
1441            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1442            snm.train()
1443
1444            for _ in range(activate_times):
1445                snm(inp)
1446
1447            state_dict = deepcopy(snm.state_dict())
1448            self.assertEqual(
1449                {
1450                    "parametrizations.weight.original",
1451                    "bias",
1452                    "parametrizations.weight.0._v",
1453                    "parametrizations.weight.0._u",
1454                },
1455                set(state_dict.keys()),
1456            )
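            # Note that the recomputed `weight` itself is not serialized: only the original
            # tensor and the power-iteration buffers appear in the state dict.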
1457
1458            # test that non-strict loading works
1459            non_strict_state_dict = deepcopy(state_dict)
1460            non_strict_state_dict["nonsense"] = "nonsense"
1461            with self.assertRaisesRegex(
1462                RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'
1463            ):
1464                snm.load_state_dict(non_strict_state_dict, strict=True)
1465            snm.load_state_dict(non_strict_state_dict, strict=False)
1466            del non_strict_state_dict["parametrizations.weight.original"]
1467            snm.load_state_dict(non_strict_state_dict, strict=False)
1468            del non_strict_state_dict["parametrizations.weight.0._u"]
1469            snm.load_state_dict(non_strict_state_dict, strict=False)
1470            del non_strict_state_dict["parametrizations.weight.0._v"]
1471            snm.load_state_dict(non_strict_state_dict, strict=False)
1472            non_strict_state_dict[
1473                "weight"
1474            ] = snm.weight.detach().clone()  # set W as a buffer
1475            snm.load_state_dict(non_strict_state_dict, strict=False)
1476            del non_strict_state_dict._metadata[
1477                "parametrizations.weight.0"
1478            ]  # remove metadata info
1479            snm.load_state_dict(non_strict_state_dict, strict=False)
1480            del non_strict_state_dict["weight"]  # remove W buffer
1481            snm.load_state_dict(non_strict_state_dict, strict=False)
1482            del non_strict_state_dict["bias"]
1483            snm.load_state_dict(non_strict_state_dict, strict=False)
1484
1485            # normal state_dict
1486
1487            # test that re-wrapping does not matter
1488            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
1489            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1490
1491            snm.load_state_dict(state_dict)
1492            with torch.no_grad():
1493                snm.eval()
1494                out0_eval = snm(inp)
1495                snm.train()
1496                out1_train = snm(inp)
1497                out2_train = snm(inp)
1498                snm.eval()
1499                out3_eval = snm(inp)
1500
1501            # test that re-wrapping does not matter
1502            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
1503            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1504
1505            # Test normal loading
1506            snm.load_state_dict(state_dict)
1507            with torch.no_grad():
1508                snm.eval()
1509                self.assertEqual(out0_eval, snm(inp))
1510                snm.train()
1511                self.assertEqual(out1_train, snm(inp))
1512                self.assertEqual(out2_train, snm(inp))
1513                snm.eval()
1514                self.assertEqual(out3_eval, snm(inp))
1515
1516    @swap([True, False])
1517    def test_new_spectral_norm_dim(self):
1518        inp = torch.randn(2, 3, 10, 12)
1519        m = nn.ConvTranspose2d(3, 4, (5, 6))
1520        m = torch.nn.utils.parametrizations.spectral_norm(m)
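        # For ConvTranspose*d modules spectral_norm defaults to dim=1 (the output-channel
        # dimension of the weight), so _u should have out_channels elements.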
1521        snm = m.parametrizations.weight[0]
1522        # this should not run into incompatible shapes
1523        x = m(inp)
1524        # check that u has the size of the dimension the spectral norm is computed over
1525        self.assertEqual(
1526            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
1527        )
1528
1529    @swap([True, False])
1530    def test_new_spectral_norm_forward(self):
1531        input = torch.randn(3, 5)
1532        m = nn.Linear(5, 7)
1533        m = torch.nn.utils.parametrizations.spectral_norm(m)
1534        snm = m.parametrizations.weight[0]
1535        # naive forward
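        # (one power-iteration step by hand: u = normalize(W @ v), v = normalize(W^T @ u),
        # then divide the weight by sigma ~= u . (W @ v) and compare against m(input))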
1536        _weight = m.parametrizations.weight.original
1537        _bias, _v = m.bias, snm._v
1538        _weight_mat = _weight.view(_weight.size(0), -1)
1539        _u = torch.mv(_weight_mat, _v)
1540        _u = F.normalize(_u, dim=0, eps=1e-12)
1541        _v = torch.mv(_weight_mat.t(), _u)
1542        _v = F.normalize(_v, dim=0, eps=1e-12)
1543        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
1544        out_hat = torch.nn.functional.linear(input, _weight, _bias)
1545        expect_out = m(input)
1546        self.assertEqual(expect_out, out_hat)
1547
1548    @swap([True, False])
1549    @skipIfTorchDynamo("Test does not work with TorchDynamo")
1550    def test_new_spectral_norm_value(self):
1551        # test that the spectral norm (= top singular value) is in fact
1552        # properly calculated, using the example of a simple diagonal matrix
1553        for dtype in (torch.float, torch.cfloat):
1554            m = nn.Linear(2, 2, dtype=dtype)
1555            with torch.no_grad():
1556                # set weight to be diagonal
1557                x = torch.diagonal(m.weight)
1558                m.weight = nn.Parameter(torch.diag(x))
1559                torch.nn.utils.parametrizations.spectral_norm(m)
1560                # the weight should be rescaled by its spectral norm (i.e., the largest diagonal element in absolute value)
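                # For W = diag(x) the singular values are |x_i|, so dividing by the
                # spectral norm max|x_i| gives diag(x / x.abs().max()).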
1561                expected = torch.diag(x / x.abs().max())
1562                self.assertEqual(m.weight.data, expected)
1563
1564    @skipIfNoLapack
1565    @swap([True, False])
1566    def test_orthogonal_parametrization(self):
1567        # Orthogonal implements 6 algorithms (3 parametrizations x 2 options for use_trivialization)
1568
1569        def assert_is_orthogonal(X):
1570            n, k = X.size(-2), X.size(-1)
1571            if n < k:
1572                X = X.mT
1573                n, k = k, n
1574            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
1575                *(X.size()[:-2]), k, k
1576            )
1577            eps = 10 * n * torch.finfo(X.dtype).eps
1578            torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)
1579
1580        def assert_weight_allclose_Q(weight, W):
1581            # Test that weight is equal to the Q part of the QR decomposition of W
1582            # (or of its transpose if the matrix is wide)
1583            wide_matrix = W.size(-2) < W.size(-1)
1584            if wide_matrix:
1585                W = W.mT
1586            Q, R = torch.linalg.qr(W)
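            # Scaling each column of Q by the sign (phase) of the matching diagonal entry of R
            # removes the sign ambiguity of the QR decomposition, making the comparison canonical.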
1587            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
1588            if wide_matrix:
1589                Q = Q.mT
1590            torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)
1591
1592        for shape, dtype, use_linear in product(
1593            ((4, 4), (5, 3), (3, 5)),  # square / tall / wide
1594            (torch.float32, torch.complex64),
1595            (True, False),
1596        ):
1597            # Conv2d does not support complex yet
1598            if not use_linear:
1599                continue
1600
1601            if use_linear:
1602                input = torch.randn(3, shape[0], dtype=dtype)
1603            else:
1604                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)
1605
1606            for parametrization, use_trivialization in product(
1607                ("matrix_exp", "cayley", "householder"), (False, True)
1608            ):
1609                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
1610                # See Note [right_inverse expm cayley]
1611                can_initialize = use_trivialization or parametrization == "householder"
1612
1613                # We generate them every time to always start with fresh weights
1614                if use_linear:
1615                    m = nn.Linear(*shape, dtype=dtype)
1616                else:
1617                    m = nn.Conv2d(2, 3, shape, dtype=dtype)
1618
1619                # We do not support householder for complex inputs
1620                # See Note [Householder complex]
1621
1622                # When using the swap_tensors path, this is needed so that the autograd
1623                # graph is not alive anymore.
1624                if get_swap_module_params_on_conversion():
1625                    w_init = m.weight.clone().detach()
1626                else:
1627                    w_init = m.weight.clone()
1628                if parametrization == "householder" and m.weight.is_complex():
1629                    msg = "householder parametrization does not support complex tensors"
1630                    with self.assertRaisesRegex(ValueError, msg):
1631                        torch.nn.utils.parametrizations.orthogonal(
1632                            m,
1633                            "weight",
1634                            parametrization,
1635                            use_trivialization=use_trivialization,
1636                        )
1637                    continue
1638
1639                wide_matrix = w_init.size(-2) < w_init.size(-1)
1640                torch.nn.utils.parametrizations.orthogonal(
1641                    m, "weight", parametrization, use_trivialization=use_trivialization
1642                )
1643                # Forwards works as expected
1644                self.assertEqual(w_init.shape, m.weight.shape)
1645                assert_is_orthogonal(m.weight)
1646                if can_initialize:
1647                    assert_weight_allclose_Q(m.weight, w_init)
1648
1649                # Initializing with a given orthogonal matrix works
1650                X = torch.randn_like(m.weight)
1651                if wide_matrix:
1652                    X = X.mT
1653                w_new = torch.linalg.qr(X).Q
1654                if wide_matrix:
1655                    w_new = w_new.mT
1656                if can_initialize:
1657                    m.weight = w_new
1658                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
1659                else:
1660                    msg = (
1661                        "assign to the matrix exponential or the Cayley parametrization"
1662                    )
1663                    with self.assertRaisesRegex(NotImplementedError, msg):
1664                        m.weight = w_new
1665
1666                # Initializing with a non-orthogonal matrix makes m.weight the Q part of the given matrix
1667                w_new = torch.randn_like(m.weight)
1668                if can_initialize:
1669                    m.weight = w_new
1670                    assert_weight_allclose_Q(m.weight, w_new)
1671                else:
1672                    msg = (
1673                        "assign to the matrix exponential or the Cayley parametrization"
1674                    )
1675                    with self.assertRaisesRegex(NotImplementedError, msg):
1676                        m.weight = w_new
1677
1678                opt = torch.optim.SGD(m.parameters(), lr=0.1)
1679                for _ in range(2):
1680                    opt.zero_grad()
1681                    m(input).norm().backward()
1682                    grad = m.parametrizations.weight.original.grad
1683                    self.assertIsNotNone(grad)
1684                    # We do not update the strictly upper triangular part if the matrix is tall/square, nor the strictly lower triangular part if it is wide
1685                    if grad.size(-2) >= grad.size(-1):
1686                        zeros_grad = grad.triu(1)
1687                    else:
1688                        zeros_grad = grad.tril(-1)
1689                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
1690                    # The gradient on the diagonal can only be imaginary because a skew-Hermitian
1691                    # matrix has a purely imaginary diagonal
1692                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
1693                    if grad.is_complex():
1694                        diag_grad = diag_grad.real
1695                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
1696                    opt.step()
1697                    assert_is_orthogonal(m.weight)
1698
1699    @skipIfNoLapack
1700    @swap([True, False])
1701    def test_orthogonal_errors(self):
1702        m = nn.Linear(3, 4)
1703        with self.assertRaisesRegex(ValueError, "has to be one of"):
1704            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")
1705
1706        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
1707            torch.nn.utils.parametrizations.orthogonal(m, "bias")
1708
1709        torch.nn.utils.parametrizations.orthogonal(m, "weight")
1710        with self.assertRaisesRegex(ValueError, "matrices of shape"):
1711            m.weight = torch.randn(5, 5)
1712        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1713
1714    @swap([True, False])
1715    def test_weight_norm_state_dict_compat(self):
1716        m = nn.Linear(4, 5)
1717        m = torch.nn.utils.weight_norm(m)
1718        old_dict = m.state_dict()
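        # The legacy weight_norm saves the decomposition under `weight_g` / `weight_v`;
        # loading it into the parametrized module presumably relies on a state-dict
        # compatibility hook that remaps those keys onto original0 / original1.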
1719
1720        m2 = nn.Linear(4, 5)
1721        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
1722        m2.load_state_dict(old_dict)
1723
1724        input = torch.randn(3, 4)
1725        self.assertEqual(m(input), m2(input))
1726
1727    @swap([True, False])
1728    def test_weight_norm_pickle(self):
1729        m = nn.Linear(4, 5)
1730        m = torch.nn.utils.parametrizations.weight_norm(m)
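        # Parametrized modules use a dynamically generated class, which plain pickle cannot
        # resolve by name; the error message is expected to point users towards state_dict.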
1731        with self.assertRaisesRegex(RuntimeError, "state_dict"):
1732            pickle.dumps(m)
1733
1734    @swap([True, False])
1735    def test_weight_norm_deepcopy(self):
1736        m = nn.Linear(4, 5)
1737        m = torch.nn.utils.parametrizations.weight_norm(m)
1738        m2 = deepcopy(m)
1739        input = torch.randn(3, 4)
1740        self.assertEqual(m(input), m2(input))
1741
1742    @swap([True])
1743    def test_wrapper_subclass_parametrization(self):
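        # TwoTensor is a testing tensor subclass that wraps two tensors (.a and .b); the
        # parametrizations below convert between plain tensors and that wrapper subclass.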
1744        class Subclassify(nn.Module):
1745            def forward(self, X):
1746                return TwoTensor(X, X)
1747
1748        class UnSubclassify(nn.Module):
1749            def forward(self, X):
1750                return X.a
1751
1752        class IdentityWithRightInverse(nn.Module):
1753            def forward(self, X):
1754                return X
1755
1756            def right_inverse(self, X):
1757                return TwoTensor(X, X)
1758
1759        def _check_parametrization(
1760            parametrization,
1761            type_before_registration,
1762            type_after_registration,
1763            leave_parametrized=False,
1764            type_after_right_inverse=None,
1765        ):
1766            model = nn.Linear(2, 2)
1767            buf = torch.randn(2, 2)
1768            model.buf = torch.nn.Buffer(buf)
1769            if (
1770                type_before_registration == TwoTensor
1771                and type_after_registration == Tensor
1772            ):
1773                model._apply(lambda t: TwoTensor(t, t))
1774            initial_weight = model.weight.clone().detach()
1775            initial_weight_id = id(model.weight)
1776            initial_buf = model.buf.clone().detach()
1777            initial_buf_id = id(model.buf)
1778            type_original_weight = (
1779                type_before_registration
1780                if type_after_right_inverse is None
1781                else type_after_right_inverse
1782            )
1783            type_original_buf = (
1784                Tensor if type_original_weight is nn.Parameter else type_original_weight
1785            )
1786            type_after_removal_buf = (
1787                type_after_registration if leave_parametrized else type_original_buf
1788            )
1789            if leave_parametrized:
1790                if type_after_registration is Tensor:
1791                    type_after_removal_weight = nn.Parameter
1792                else:
1793                    type_after_removal_weight = type_after_registration
1794            else:
1795                type_after_removal_weight = type_original_weight
1796
1797            parametrize.register_parametrization(model, "weight", parametrization())
1798            parametrize.register_parametrization(model, "buf", parametrization())
1799            self.assertTrue(hasattr(model, "parametrizations"))
1800            self.assertTrue(parametrize.is_parametrized(model))
1801            self.assertFalse(parametrize.is_parametrized(model, "bias"))
1802            # checks for weight
1803            self.assertTrue(parametrize.is_parametrized(model, "weight"))
1804            self.assertTrue(
1805                isinstance(model.parametrizations.weight.original, nn.Parameter)
1806            )
1807            self.assertTrue(
1808                type(model.parametrizations.weight.original) is type_original_weight
1809            )
1810            self.assertNotIn("weight", model._parameters)
1811            self.assertTrue(type(model.weight) is type_after_registration)
1812            # checks for buf
1813            self.assertTrue(parametrize.is_parametrized(model, "buf"))
1814            self.assertFalse(
1815                isinstance(model.parametrizations.buf.original, nn.Parameter)
1816            )
1817            self.assertTrue(
1818                type(model.parametrizations.buf.original) is type_original_buf
1819            )
1820            self.assertTrue(type(model.buf) is type_after_registration)
1821            parametrize.remove_parametrizations(
1822                model, "weight", leave_parametrized=leave_parametrized
1823            )
1824            parametrize.remove_parametrizations(
1825                model, "buf", leave_parametrized=leave_parametrized
1826            )
1827            self.assertFalse(hasattr(model, "parametrizations"))
1828            self.assertEqual(model.__class__, nn.Linear)
1829            # checks for weight
1830            self.assertTrue(type(model.weight) is type_after_removal_weight)
1831            self.assertTrue(isinstance(model.weight, nn.Parameter))
1832            self.assertEqual(id(model.weight), initial_weight_id)
1833            # checks for buf
1834            self.assertTrue(type(model.buf) is type_after_removal_buf)
1835            self.assertFalse(isinstance(model.buf, nn.Parameter))
1836            self.assertEqual(id(model.buf), initial_buf_id)
1837            if not leave_parametrized and type_after_right_inverse is None:
1838                self.assertEqual(model.weight, initial_weight)
1839                self.assertEqual(model.buf, initial_buf)
1840
1841        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
1842        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
1843        _check_parametrization(
1844            IdentityWithRightInverse,
1845            nn.Parameter,
1846            TwoTensor,
1847            type_after_right_inverse=TwoTensor,
1848        )
1849        _check_parametrization(
1850            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
1851        )
1852        _check_parametrization(
1853            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
1854        )
1855        _check_parametrization(
1856            IdentityWithRightInverse,
1857            nn.Parameter,
1858            TwoTensor,
1859            leave_parametrized=True,
1860            type_after_right_inverse=TwoTensor,
1861        )
1862
1863
1864class TestNNParametrizationDevice(NNTestCase):
1865    @swap([True, False])
1866    def test_weight_norm_parametrization(self, device):
1867        for dtype in [torch.float, torch.bfloat16]:
1868            input = torch.randn(3, 4, dtype=dtype, device=device)
1869            m = nn.Linear(4, 5, dtype=dtype, device=device)
1870            expected_output = m(input)
1871
1872            # add weight normalization
1873            m = torch.nn.utils.parametrizations.weight_norm(m)
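            # weight_norm rewrites weight as g * v / ||v||: original0 holds the norms g
            # (shape (out_features, 1) for the default dim=0) and original1 holds the
            # direction v with the same shape as weight, matching the size checks below.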
1874            self.assertEqual(
1875                m.parametrizations.weight.original1.size(), m.weight.size()
1876            )
1877            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
1878            self.assertEqual(m(input), expected_output)
1879
1880            # remove weight norm
1881            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1882            self.assertFalse(hasattr(m, "parametrizations"))
1883            self.assertEqual(m(input), expected_output)
1884
1885            # test with dim=1
1886            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
1887            self.assertEqual(
1888                m.parametrizations.weight.original1.size(), m.weight.size()
1889            )
1890            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
1891            self.assertEqual(m(input), expected_output)
1892
1893            # test with dim=None
1894            m = nn.Linear(4, 5, dtype=dtype, device=device)
1895            expected_output = m(input)
1896            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
1897            self.assertEqual(m(input), expected_output)
1898
1899
1900only_for = ("cpu", "cuda")
1901instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
1902instantiate_parametrized_tests(TestNNParametrization)
1903
1904if __name__ == "__main__":
1905    run_tests()
1906