# Owner(s): ["module: nn"]
import pickle
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    gradcheck,
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    skipIfNoLapack,
    skipIfTorchDynamo,
    swap,
    TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor


class TestNNParametrization(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    # torch/nn/utils/parametrize
    @skipIfNoLapack
    @swap([True, False])
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer, and that removing them restores the initial state.
        It also tests that backpropagating through them works as expected.
        """

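        # A quick sketch of the mechanics under test (per the public
        # torch.nn.utils.parametrize docs): registering a parametrization moves
        # the original tensor to `module.parametrizations.<name>.original` and
        # redefines attribute access as `forward(original)`; removing it either
        # restores the original tensor (leave_parametrized=False) or materializes
        # the current value in its place (leave_parametrized=True).
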
        # Define a couple of matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T
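        # By construction the output is skew-symmetric: tril(-1) keeps only the
        # strict lower triangle, so A = X.tril(-1) - X.tril(-1).T satisfies A == -A.T.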

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()
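        # For reference: the Cayley map Q = (I + X)^{-1} (I - X) is orthogonal
        # whenever X is skew-symmetric, since (I + X) and (I - X) commute and
        # hence Q.T @ Q = (I + X)(I - X)^{-1}(I + X)^{-1}(I - X) = Id.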

        class Resize(nn.Module):
            def forward(self, X):
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple of vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Test unsafe flag
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
            model(torch.ones(8, 8))

        # One parametrization with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X
        # Structure tests
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)
        # Remove
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        # Basic tests
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
        # Should not throw

        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
        # Should not throw
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # The module is not parametrized
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
        for _ in range(2):
            parametrize.register_parametrization(model, "weight", Skew())
            parametrize.register_parametrization(model, "weight", Orthogonal())
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            # We didn't change the dtype, nor did we use multiple inputs, so the id should be the same
            self.assertEqual(id(model.weight), initial_weight_id)
            self.assertEqual(id(model.bias), initial_bias_id)

            # Should not throw. Things are updated
            weight_copy = model.weight.clone()
            bias_copy = model.bias.clone()
            sgd.zero_grad()
            (model.weight.T @ model.bias).sum().backward()
            sgd.step()
            self.assertNotEqual(model.weight, weight_copy)
            self.assertNotEqual(model.bias, bias_copy)
            if get_swap_module_params_on_conversion():
                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                del weight_copy, bias_copy

    @swap([True, False])
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest parametrizations,
        i.e. that the original parameter can itself be parametrized again
        """

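        # When `weight` is parametrized, `model.parametrizations.weight` is a
        # ParametrizationList module that holds the unconstrained tensor as its
        # `original` parameter, and that parameter can be parametrized in turn.
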
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A

        # Add nested parametrization
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

    @swap([True, False])
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

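        # For buffers the bookkeeping is analogous: the unconstrained tensor is
        # stored as a buffer at `model.parametrizations.bias.original`, so the
        # parameter count of the module is unchanged (checked below).
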
        # Define a couple of vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)
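        # Because `id` and `B` are registered as buffers, they are part of
        # `state_dict()`; that is what lets this stateful parametrization
        # round-trip through the save/load below.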

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()

        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)
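        # The contract being exercised: assigning `model.weight = X` routes X
        # through `right_inverse` and stores the result as the unconstrained
        # `original`, so that `forward(right_inverse(X)) == X` is expected to hold.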

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)
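        # With stacked parametrizations, assignment applies each right_inverse in
        # reverse registration order; returning zeros works here because
        # forward(torch.zeros_like(X)) reproduces the freshly stored self.B.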

        N = 5
        model = nn.Linear(N, N)
        # Register the skew-symmetric constraint. The result is now skew-symmetric
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way.
        # Register now the Cayley map; the result is now orthogonal
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

    @swap([True, False])
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Registering a parametrization on a non-existing parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
                return z, torch.zeros_like(z)
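        # With a multi-tensor right_inverse, the pieces are stored as
        # `original0`, `original1`, ..., and the single original tensor cannot be
        # recovered, hence the `leave_parametrized=False` error below.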

        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return None

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # Change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of length one, although weird, is correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)
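        # A length-1 sequence still takes the multi-tensor path, so the
        # unconstrained tensor is stored as `original0` rather than `original`.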
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Worker        parametrize.register_parametrization(module, "weight", SequenceLen1())
645*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
646*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
647*da0073e9SAndroid Build Coastguard Worker        _ = module.weight  # Does not throw
648*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(parametrize.is_parametrized(module))
649*da0073e9SAndroid Build Coastguard Worker        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker        # None of the operations above should have altered the weight
652*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(parametrize.is_parametrized(module))
653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.weight, weight_init)
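
        # A contrasting sketch (not in the original test): when right_inverse
        # returns a bare tensor rather than a sequence, the unparametrized
        # tensor is stored as `original`, with no index suffix. `BareInverse`
        # is a hypothetical parametrization added for illustration.
        class BareInverse(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return w

        parametrize.register_parametrization(module, "weight", BareInverse())
        self.assertTrue(hasattr(module.parametrizations.weight, "original"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original0"))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)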

    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on a parametrized tensor

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())

        # Has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
        # The failures below are mostly caused by bugs in the parametrization's own code

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
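
        # For contrast, a minimal sketch (not part of the original test): a
        # parametrization whose forward and right_inverse both preserve dtype
        # and shape stacks on top of the existing Identity without error.
        # `Scale` is a hypothetical parametrization added for illustration.
        class Scale(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        parametrize.register_parametrization(module, "weight", Scale())
        self.assertEqual(len(module.parametrizations.weight), 2)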

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        # A parametrization whose forward takes several inputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # Project the given matrix onto the set of rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in descending order
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # A simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)
        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
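
        # Illustrative check (not part of the original test): the visible
        # weight is recomputed by RankOne.forward from the two stored vectors.
        parametrization = model.parametrizations.weight
        self.assertEqual(
            model.weight,
            parametrization[0](parametrization.original0, parametrization.original1),
        )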

        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank-1 matrix onto the set of rank-1 matrices does not change it
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix is now twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model now has three parameters
        self.assertEqual(len(list(model.parameters())), 3)
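
        # Illustrative check (not part of the original test): the three
        # parameters are the bias plus the two tensors stored by RankOne's
        # right_inverse (Double stores no tensors of its own).
        self.assertEqual(
            {name for name, _ in model.named_parameters()},
            {
                "bias",
                "parametrizations.weight.original0",
                "parametrizations.weight.original1",
            },
        )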

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        # Test backward. Should not throw
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

        # Same drill as before: removing should work as expected
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # The model now has two parameters
        self.assertEqual(len(list(model.parameters())), 2)

        # Test backward. Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization"""

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Test that the caching system works
        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))
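
        # Outside the cached() context every attribute access recomputes the
        # parametrization, so the results are distinct tensor objects
        # (an illustrative sketch, not part of the original test).
        W1 = model.weight
        W2 = model.weight
        self.assertNotEqual(id(W1), id(W2))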

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching"""

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

            A = to_model.weight
            B = to_model.weight
            self.assertEqual(id(A), id(B))

            # test that the results are distinct objects for each module
            self.assertNotEqual(id(A), id(X))
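
            # Illustrative check (not part of the original test): the cached
            # tensors are distinct objects but hold equal values, since the
            # transfer copied the parametrizations and their parameters.
            self.assertEqual(A, X)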

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Test that the training mode is updated on parametrization registration"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)
        module.train()
        parametrize.register_parametrization(module, "weight", Identity().eval())
        self.assertTrue(module.parametrizations.weight[0].training)
        self.assertTrue(module.parametrizations.weight[1].training)
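
        # Illustrative check (not part of the original test): train()/eval()
        # on the module propagates to the registered parametrizations, since
        # they live in the `parametrizations` submodule.
        module.eval()
        self.assertFalse(module.parametrizations.weight[0].training)
        self.assertFalse(module.parametrizations.weight[1].training)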

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertEqual(parametrize.type_before_parametrizations(model), original_type)
        parametrize.register_parametrization(model, "weight", Identity())
        self.assertEqual(parametrize.type_before_parametrizations(model), original_type)
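
        # For contrast (an illustrative check, not part of the original test):
        # registration swaps the module's class for a dynamically generated
        # parametrized subclass, so plain type() no longer matches.
        self.assertNotEqual(type(model), original_type)
        self.assertTrue(issubclass(type(model), original_type))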

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate a custom implementation of deepcopying
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = (
                m1.parametrizations.bias.original
                if parametrize.is_parametrized(m1, "bias")
                else m1.bias
            )
            b2 = (
                m2.parametrizations.bias.original
                if parametrize.is_parametrized(m2, "bias")
                else m2.bias
            )
            # Weights, biases and attributes should be equal but must be different objects
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create a deepcopy
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
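
            # Illustrative check (not part of the original test): the copy is
            # itself parametrized and computes the same forward value.
            model_copy = deepcopy(model)
            self.assertTrue(parametrize.is_parametrized(model_copy, "weight"))
            self.assertEqual(model.weight, model_copy.weight)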

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # test that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertFalse(hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)
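
        # Illustrative check (not part of the original test): with Double's
        # right_inverse, the stored tensor is half the visible weight, so
        # forward reconstructs the registered value exactly.
        self.assertEqual(model.parametrizations.weight.original, 0.5 * model.weight)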

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        parametrize.register_parametrization(model, "bias", AddOne())
        parametrize.register_parametrization(model, "bias", Double())
        parametrize.register_parametrization(model, "bias", MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertNotIn("bias", to_model.parametrizations)
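
        # Illustrative check (not part of the original test): the bias was not
        # transferred, so it stays a plain parameter on to_model.
        self.assertFalse(parametrize.is_parametrized(to_model, "bias"))
        self.assertIn("bias", to_model._parameters)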

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        # A parametrization whose forward takes several inputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # Project the given matrix onto the set of rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in descending order
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )

        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original0,
            to_model.parametrizations.weight.original0,
        )
        self.assertEqual(
            model.parametrizations.weight.original1,
            to_model.parametrizations.weight.original1,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertFalse(hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # also check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original0,
            to_model.parametrizations.test_param.original0,
        )
        self.assertEqual(
            model.parametrizations.test_param.original1,
            to_model.parametrizations.test_param.original1,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_new_spectral_norm(self):
        with set_default_dtype(torch.double):
            input = torch.randn(3, 5)
            m = nn.Linear(5, 7)
            m = torch.nn.utils.parametrizations.spectral_norm(m)
            spectral_norm_m = m.parametrizations.weight[0]

            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))

            # .parametrizations.weight.original should be trainable
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
            self.assertTrue("original" in m.parametrizations.weight._parameters)

            # u should be just a reused buffer
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
            self.assertTrue("_u" in spectral_norm_m._buffers)
            self.assertTrue("_v" in spectral_norm_m._buffers)
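
            # Illustrative check (not part of the original test): the power
            # iteration vectors are buffers, not trainable parameters.
            self.assertNotIn("_u", spectral_norm_m._parameters)
            self.assertNotIn("_v", spectral_norm_m._parameters)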

            # weight should be a plain attribute, not counted as a buffer or a param
            self.assertIsNotNone(m.weight)
            self.assertFalse("weight" in m._buffers)
            self.assertFalse("weight" in m._parameters)

            # it should also share storage with `parametrizations.weight.original`
            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
            self.assertEqual(
                m.parametrizations.weight.original.stride(), m.weight.stride()
            )

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

            # spectral_norm is the only parametrization
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)

            # We can register spectral_norm multiple times on the same parameter
            # and on multiple parameters in the same module
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")

            # If we remove the parametrization on bias, weight is still parametrized
            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
            self.assertTrue("bias" in m._parameters)
            self.assertTrue(hasattr(m, "parametrizations"))
            self.assertFalse("weight" in m._parameters)

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            # Neither weight nor bias is parametrized
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)
            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))
1256*da0073e9SAndroid Build Coastguard Worker
1257*da0073e9SAndroid Build Coastguard Worker            # test correctness in training/eval modes and cpu/multi-gpu settings
1258*da0073e9SAndroid Build Coastguard Worker            for apply_dp in (True, False):
1259*da0073e9SAndroid Build Coastguard Worker                if apply_dp:
1260*da0073e9SAndroid Build Coastguard Worker                    if not TEST_MULTIGPU:
1261*da0073e9SAndroid Build Coastguard Worker                        continue
1262*da0073e9SAndroid Build Coastguard Worker                    device = torch.device("cuda:0")
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker                    def maybe_wrap(m):
1265*da0073e9SAndroid Build Coastguard Worker                        return torch.nn.DataParallel(m, [0, 1])
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker                else:
1268*da0073e9SAndroid Build Coastguard Worker                    device = torch.device("cpu")
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker                    def maybe_wrap(m):
1271*da0073e9SAndroid Build Coastguard Worker                        return m
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker                for requires_grad in (True, False):
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker                    def get_modules():
1276*da0073e9SAndroid Build Coastguard Worker                        m = nn.Linear(3, 4).to(device)
1277*da0073e9SAndroid Build Coastguard Worker                        m.weight.requires_grad_(requires_grad)
1278*da0073e9SAndroid Build Coastguard Worker                        m = torch.nn.utils.parametrizations.spectral_norm(m)
1279*da0073e9SAndroid Build Coastguard Worker                        wrapped_m = maybe_wrap(m)
1280*da0073e9SAndroid Build Coastguard Worker                        spectral_norm_m = m.parametrizations.weight[0]
1281*da0073e9SAndroid Build Coastguard Worker                        return m, wrapped_m, spectral_norm_m
1282*da0073e9SAndroid Build Coastguard Worker
1283*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(2, 3, device=device)
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker                    m, wrapped_m, spectral_norm_m = get_modules()
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(hasattr(spectral_norm_m, "_u"))
1288*da0073e9SAndroid Build Coastguard Worker                    u0 = spectral_norm_m._u.clone()
1289*da0073e9SAndroid Build Coastguard Worker                    v0 = spectral_norm_m._v.clone()
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker                    # TEST TRAINING BEHAVIOR
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker                    # We perform GD first to modify the initial matrix
1294*da0073e9SAndroid Build Coastguard Worker                    opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker                    opt.zero_grad()
1297*da0073e9SAndroid Build Coastguard Worker                    wrapped_m(input).sum().backward()
1298*da0073e9SAndroid Build Coastguard Worker                    opt.step()
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker                    out = wrapped_m(input)
1301*da0073e9SAndroid Build Coastguard Worker                    if requires_grad:
1302*da0073e9SAndroid Build Coastguard Worker                        # the second forward should have updated u and v
1303*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(u0, spectral_norm_m._u)
1304*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(v0, spectral_norm_m._v)
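                        # Each training-mode forward runs one power-iteration step,
                        # roughly: u <- normalize(W @ v); v <- normalize(W.T @ u)
                        # (a sketch of the update, not the exact implementation),
                        # which is why u and v drift from their initial values.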
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker                    # assert that backprop reaches the original weight
1307*da0073e9SAndroid Build Coastguard Worker                    # we can't use gradcheck here because the function changes as we
1308*da0073e9SAndroid Build Coastguard Worker                    # forward through it in training mode (each forward runs a power iteration)
1309*da0073e9SAndroid Build Coastguard Worker                    if requires_grad:
1310*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.grad(
1311*da0073e9SAndroid Build Coastguard Worker                            out.sum(), m.parametrizations.weight.original
1312*da0073e9SAndroid Build Coastguard Worker                        )
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker                    # test that backward works with multiple forwards
1315*da0073e9SAndroid Build Coastguard Worker                    # training mode updates the `u` and `v` vectors, so we reset them to the
1316*da0073e9SAndroid Build Coastguard Worker                    # same values at the start of each call for the finite-difference test to pass
1317*da0073e9SAndroid Build Coastguard Worker                    saved_u = spectral_norm_m._u.clone()
1318*da0073e9SAndroid Build Coastguard Worker                    saved_v = spectral_norm_m._v.clone()
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker                    def fn(input):
1321*da0073e9SAndroid Build Coastguard Worker                        spectral_norm_m._u.data.copy_(saved_u)
1322*da0073e9SAndroid Build Coastguard Worker                        spectral_norm_m._v.data.copy_(saved_v)
1323*da0073e9SAndroid Build Coastguard Worker                        out0 = wrapped_m(input)
1324*da0073e9SAndroid Build Coastguard Worker                        out1 = wrapped_m(input)
1325*da0073e9SAndroid Build Coastguard Worker                        return out0 + out1
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker                    # Make sure we can compute gradients w.r.t. all the parameters in the
1328*da0073e9SAndroid Build Coastguard Worker                    # case of a double forward
1329*da0073e9SAndroid Build Coastguard Worker                    fn(input.clone().requires_grad_()).sum().backward()
1330*da0073e9SAndroid Build Coastguard Worker                    gradcheck(
1331*da0073e9SAndroid Build Coastguard Worker                        fn, (input.clone().requires_grad_(),), check_batched_grad=False
1332*da0073e9SAndroid Build Coastguard Worker                    )
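                    # check_batched_grad=False presumably because fn mutates the
                    # power-iteration buffers in place, which batched gradient
                    # checking cannot handle.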
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker                    # test removing
1335*da0073e9SAndroid Build Coastguard Worker                    # the spectral norm module needs to be in eval mode if we'd like to
1336*da0073e9SAndroid Build Coastguard Worker                    # avoid doing another power iteration
1337*da0073e9SAndroid Build Coastguard Worker                    m, wrapped_m, _ = get_modules()
1338*da0073e9SAndroid Build Coastguard Worker                    pre_remove_out = wrapped_m(input)
1339*da0073e9SAndroid Build Coastguard Worker                    if get_swap_module_params_on_conversion():
1340*da0073e9SAndroid Build Coastguard Worker                        # When using the swap_tensors path, detach so that the old
1341*da0073e9SAndroid Build Coastguard Worker                        # output's autograd graph is no longer kept alive.
1342*da0073e9SAndroid Build Coastguard Worker                        pre_remove_out_ref = pre_remove_out.detach()
1343*da0073e9SAndroid Build Coastguard Worker                        del pre_remove_out
1344*da0073e9SAndroid Build Coastguard Worker                    else:
1345*da0073e9SAndroid Build Coastguard Worker                        pre_remove_out_ref = pre_remove_out
1346*da0073e9SAndroid Build Coastguard Worker                    m.eval()
1347*da0073e9SAndroid Build Coastguard Worker                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1348*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
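                    # remove_parametrizations bakes the last computed value into a
                    # plain weight Parameter, so the forward must match exactly.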
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker                    torch.nn.utils.parametrizations.spectral_norm(m)
1351*da0073e9SAndroid Build Coastguard Worker                    for _ in range(3):
1352*da0073e9SAndroid Build Coastguard Worker                        pre_remove_out = wrapped_m(input)
1353*da0073e9SAndroid Build Coastguard Worker                    if get_swap_module_params_on_conversion():
1354*da0073e9SAndroid Build Coastguard Worker                        # When using the swap_tensors path, detach so that the old
1355*da0073e9SAndroid Build Coastguard Worker                        # output's autograd graph is no longer kept alive.
1356*da0073e9SAndroid Build Coastguard Worker                        pre_remove_out_ref = pre_remove_out.detach()
1357*da0073e9SAndroid Build Coastguard Worker                        del pre_remove_out
1358*da0073e9SAndroid Build Coastguard Worker                    else:
1359*da0073e9SAndroid Build Coastguard Worker                        pre_remove_out_ref = pre_remove_out
1360*da0073e9SAndroid Build Coastguard Worker                    m.eval()
1361*da0073e9SAndroid Build Coastguard Worker                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1362*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker                    # TEST EVAL BEHAVIOR
1365*da0073e9SAndroid Build Coastguard Worker                    m, wrapped_m, spectral_norm_m = get_modules()
1366*da0073e9SAndroid Build Coastguard Worker                    wrapped_m(input)
1367*da0073e9SAndroid Build Coastguard Worker                    last_train_out = wrapped_m(input)
1368*da0073e9SAndroid Build Coastguard Worker                    last_train_u = spectral_norm_m._u.clone()
1369*da0073e9SAndroid Build Coastguard Worker                    last_train_v = spectral_norm_m._v.clone()
1370*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.zero_grad()
1371*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.eval()
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker                    eval_out0 = wrapped_m(input)
1374*da0073e9SAndroid Build Coastguard Worker                    # assert eval gives same result as last training iteration
1375*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(eval_out0, last_train_out)
1376*da0073e9SAndroid Build Coastguard Worker                    # assert that more iterations in eval do not change things
1377*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(eval_out0, wrapped_m(input))
1378*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(last_train_u, spectral_norm_m._u)
1379*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(last_train_v, spectral_norm_m._v)
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker                    # FIXME: the code below is flaky when executed with DataParallel
1382*da0073e9SAndroid Build Coastguard Worker                    # see https://github.com/pytorch/pytorch/issues/13818
1383*da0073e9SAndroid Build Coastguard Worker                    if apply_dp:
1384*da0073e9SAndroid Build Coastguard Worker                        continue
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker                    # test that backward works with multiple forwards in mixed training
1387*da0073e9SAndroid Build Coastguard Worker                    # and eval modes
1388*da0073e9SAndroid Build Coastguard Worker                    # training mode updates `u` and `v`, so we reset them to the same
1389*da0073e9SAndroid Build Coastguard Worker                    # values at the start of each call for the finite-difference test to pass
1390*da0073e9SAndroid Build Coastguard Worker                    saved_u = spectral_norm_m._u.clone()
1391*da0073e9SAndroid Build Coastguard Worker                    saved_v = spectral_norm_m._v.clone()
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker                    def fn(input):
1394*da0073e9SAndroid Build Coastguard Worker                        spectral_norm_m._u.data.copy_(saved_u)
1395*da0073e9SAndroid Build Coastguard Worker                        spectral_norm_m._v.data.copy_(saved_v)
1396*da0073e9SAndroid Build Coastguard Worker                        wrapped_m.train()
1397*da0073e9SAndroid Build Coastguard Worker                        out0 = wrapped_m(input)
1398*da0073e9SAndroid Build Coastguard Worker                        wrapped_m.eval()
1399*da0073e9SAndroid Build Coastguard Worker                        out1 = wrapped_m(input)
1400*da0073e9SAndroid Build Coastguard Worker                        wrapped_m.train()
1401*da0073e9SAndroid Build Coastguard Worker                        out2 = wrapped_m(input)
1402*da0073e9SAndroid Build Coastguard Worker                        wrapped_m.eval()
1403*da0073e9SAndroid Build Coastguard Worker                        out3 = wrapped_m(input)
1404*da0073e9SAndroid Build Coastguard Worker                        return out0 + out1 + out2 + out3
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (input.clone().requires_grad_(),))
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker                    # assert that backprop reaches parametrizations.weight.original in eval
1409*da0073e9SAndroid Build Coastguard Worker                    if requires_grad:
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker                        def fn(weight):
1412*da0073e9SAndroid Build Coastguard Worker                            return wrapped_m(input)
1413*da0073e9SAndroid Build Coastguard Worker
1414*da0073e9SAndroid Build Coastguard Worker                        gradcheck(fn, (m.parametrizations.weight.original,))
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker    def test_register_parametrization_no_grad(self):
1417*da0073e9SAndroid Build Coastguard Worker        r"""Test that it is possible to register a parametrization without gradient"""
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker        class SplitAndCat(nn.Module):
1420*da0073e9SAndroid Build Coastguard Worker            def right_inverse(self, x):
1421*da0073e9SAndroid Build Coastguard Worker                # split the tensor into two halves
1422*da0073e9SAndroid Build Coastguard Worker                return torch.split(x, x.shape[1] // 2)
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker            def forward(self, x0, x1):
1425*da0073e9SAndroid Build Coastguard Worker                return torch.cat([x0, x1])
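        # When right_inverse returns a sequence of tensors, parametrize registers
        # each element as original0, original1, ... and passes them back to
        # forward() as separate arguments to reconstruct the property.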
1426*da0073e9SAndroid Build Coastguard Worker
1427*da0073e9SAndroid Build Coastguard Worker        model = nn.Linear(8, 8)
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker        model.weight.requires_grad = False
1430*da0073e9SAndroid Build Coastguard Worker        parametrize.register_parametrization(model, "weight", SplitAndCat())
1431*da0073e9SAndroid Build Coastguard Worker        # make sure the parametrized and decomposed Tensors both have requires_grad == False
1432*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(model.weight.requires_grad)
1433*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
1434*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(model.parametrizations.weight.original1.requires_grad)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1437*da0073e9SAndroid Build Coastguard Worker    def test_new_spectral_norm_load_state_dict(self):
1438*da0073e9SAndroid Build Coastguard Worker        for activate_times in (0, 3):
1439*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(2, 3)
1440*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(3, 5)
1441*da0073e9SAndroid Build Coastguard Worker            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1442*da0073e9SAndroid Build Coastguard Worker            snm.train()
1443*da0073e9SAndroid Build Coastguard Worker
1444*da0073e9SAndroid Build Coastguard Worker            for _ in range(activate_times):
1445*da0073e9SAndroid Build Coastguard Worker                snm(inp)
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker            state_dict = deepcopy(snm.state_dict())
1448*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1449*da0073e9SAndroid Build Coastguard Worker                {
1450*da0073e9SAndroid Build Coastguard Worker                    "parametrizations.weight.original",
1451*da0073e9SAndroid Build Coastguard Worker                    "bias",
1452*da0073e9SAndroid Build Coastguard Worker                    "parametrizations.weight.0._v",
1453*da0073e9SAndroid Build Coastguard Worker                    "parametrizations.weight.0._u",
1454*da0073e9SAndroid Build Coastguard Worker                },
1455*da0073e9SAndroid Build Coastguard Worker                set(state_dict.keys()),
1456*da0073e9SAndroid Build Coastguard Worker            )
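            # Only the raw weight ("parametrizations.weight.original") and the
            # power-iteration vectors ("...0._u" / "...0._v") are serialized;
            # the exposed "weight" is recomputed on the fly and not saved.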
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker            # test that non-strict loading works
1459*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict = deepcopy(state_dict)
1460*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict["nonsense"] = "nonsense"
1461*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1462*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'
1463*da0073e9SAndroid Build Coastguard Worker            ):
1464*da0073e9SAndroid Build Coastguard Worker                snm.load_state_dict(non_strict_state_dict, strict=True)
1465*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1466*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict["parametrizations.weight.original"]
1467*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1468*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict["parametrizations.weight.0._u"]
1469*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1470*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict["parametrizations.weight.0._v"]
1471*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1472*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict[
1473*da0073e9SAndroid Build Coastguard Worker                "weight"
1474*da0073e9SAndroid Build Coastguard Worker            ] = snm.weight.detach().clone()  # set W as a buffer
1475*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1476*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict._metadata[
1477*da0073e9SAndroid Build Coastguard Worker                "parametrizations.weight.0"
1478*da0073e9SAndroid Build Coastguard Worker            ]  # remove metadata info
1479*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1480*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict["weight"]  # remove W buffer
1481*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1482*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict["bias"]
1483*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker            # test loading the unmodified state_dict
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker            # test that re-wrapping does not matter
1488*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
1489*da0073e9SAndroid Build Coastguard Worker            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(state_dict)
1492*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1493*da0073e9SAndroid Build Coastguard Worker                snm.eval()
1494*da0073e9SAndroid Build Coastguard Worker                out0_eval = snm(inp)
1495*da0073e9SAndroid Build Coastguard Worker                snm.train()
1496*da0073e9SAndroid Build Coastguard Worker                out1_train = snm(inp)
1497*da0073e9SAndroid Build Coastguard Worker                out2_train = snm(inp)
1498*da0073e9SAndroid Build Coastguard Worker                snm.eval()
1499*da0073e9SAndroid Build Coastguard Worker                out3_eval = snm(inp)
1500*da0073e9SAndroid Build Coastguard Worker
1501*da0073e9SAndroid Build Coastguard Worker            # test that re-wrapping does not matter
1502*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
1503*da0073e9SAndroid Build Coastguard Worker            snm = torch.nn.utils.parametrizations.spectral_norm(m)
1504*da0073e9SAndroid Build Coastguard Worker
1505*da0073e9SAndroid Build Coastguard Worker            # Test normal loading
1506*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(state_dict)
1507*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1508*da0073e9SAndroid Build Coastguard Worker                snm.eval()
1509*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out0_eval, snm(inp))
1510*da0073e9SAndroid Build Coastguard Worker                snm.train()
1511*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out1_train, snm(inp))
1512*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out2_train, snm(inp))
1513*da0073e9SAndroid Build Coastguard Worker                snm.eval()
1514*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out3_eval, snm(inp))
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1517*da0073e9SAndroid Build Coastguard Worker    def test_new_spectral_norm_dim(self):
1518*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 3, 10, 12)
1519*da0073e9SAndroid Build Coastguard Worker        m = nn.ConvTranspose2d(3, 4, (5, 6))
1520*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.parametrizations.spectral_norm(m)
1521*da0073e9SAndroid Build Coastguard Worker        snm = m.parametrizations.weight[0]
1522*da0073e9SAndroid Build Coastguard Worker        # this should not run into incompatible shapes
1523*da0073e9SAndroid Build Coastguard Worker        x = m(inp)
1524*da0073e9SAndroid Build Coastguard Worker        # check that u matches the dim spectral norm operates on (dim=1 for ConvTranspose2d)
1525*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1526*da0073e9SAndroid Build Coastguard Worker            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
1527*da0073e9SAndroid Build Coastguard Worker        )
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1530*da0073e9SAndroid Build Coastguard Worker    def test_new_spectral_norm_forward(self):
1531*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 5)
1532*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(5, 7)
1533*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.parametrizations.spectral_norm(m)
1534*da0073e9SAndroid Build Coastguard Worker        snm = m.parametrizations.weight[0]
1535*da0073e9SAndroid Build Coastguard Worker        # naive forward
1536*da0073e9SAndroid Build Coastguard Worker        _weight = m.parametrizations.weight.original
1537*da0073e9SAndroid Build Coastguard Worker        _bias, _v = m.bias, snm._v
1538*da0073e9SAndroid Build Coastguard Worker        _weight_mat = _weight.view(_weight.size(0), -1)
1539*da0073e9SAndroid Build Coastguard Worker        _u = torch.mv(_weight_mat, _v)
1540*da0073e9SAndroid Build Coastguard Worker        _u = F.normalize(_u, dim=0, eps=1e-12)
1541*da0073e9SAndroid Build Coastguard Worker        _v = torch.mv(_weight_mat.t(), _u)
1542*da0073e9SAndroid Build Coastguard Worker        _v = F.normalize(_v, dim=0, eps=1e-12)
1543*da0073e9SAndroid Build Coastguard Worker        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
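        # sigma = u^T W v estimates the top singular value after one more
        # power-iteration step; dividing the weight by sigma normalizes its
        # spectral norm to (approximately) 1, matching the module's forward.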
1544*da0073e9SAndroid Build Coastguard Worker        out_hat = torch.nn.functional.linear(input, _weight, _bias)
1545*da0073e9SAndroid Build Coastguard Worker        expect_out = m(input)
1546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect_out, out_hat)
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1549*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test does not work with TorchDynamo")
1550*da0073e9SAndroid Build Coastguard Worker    def test_new_spectral_norm_value(self):
1551*da0073e9SAndroid Build Coastguard Worker        # test that the spectral norm (= top singular value) is in fact
1552*da0073e9SAndroid Build Coastguard Worker        # properly calculated, using the example of a simple diagonal matrix
1553*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.float, torch.cfloat):
1554*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(2, 2, dtype=dtype)
1555*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1556*da0073e9SAndroid Build Coastguard Worker                # set weight to be diagonal
1557*da0073e9SAndroid Build Coastguard Worker                x = torch.diagonal(m.weight)
1558*da0073e9SAndroid Build Coastguard Worker                m.weight = nn.Parameter(torch.diag(x))
1559*da0073e9SAndroid Build Coastguard Worker                torch.nn.utils.parametrizations.spectral_norm(m)
1560*da0073e9SAndroid Build Coastguard Worker                # the weight should be rescaled by its spectral norm, i.e. the largest diagonal element in absolute value
1561*da0073e9SAndroid Build Coastguard Worker                expected = torch.diag(x / x.abs().max())
1562*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(m.weight.data, expected)
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker    @skipIfNoLapack
1565*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1566*da0073e9SAndroid Build Coastguard Worker    def test_orthogonal_parametrization(self):
1567*da0073e9SAndroid Build Coastguard Worker        # Orthogonal implements 6 algorithms (3 parametrizations x 2 options for use_trivialization)
1568*da0073e9SAndroid Build Coastguard Worker
1569*da0073e9SAndroid Build Coastguard Worker        def assert_is_orthogonal(X):
1570*da0073e9SAndroid Build Coastguard Worker            n, k = X.size(-2), X.size(-1)
1571*da0073e9SAndroid Build Coastguard Worker            if n < k:
1572*da0073e9SAndroid Build Coastguard Worker                X = X.mT
1573*da0073e9SAndroid Build Coastguard Worker                n, k = k, n
1574*da0073e9SAndroid Build Coastguard Worker            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
1575*da0073e9SAndroid Build Coastguard Worker                *(X.size()[:-2]), k, k
1576*da0073e9SAndroid Build Coastguard Worker            )
1577*da0073e9SAndroid Build Coastguard Worker            eps = 10 * n * torch.finfo(X.dtype).eps
1578*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)
1579*da0073e9SAndroid Build Coastguard Worker
1580*da0073e9SAndroid Build Coastguard Worker        def assert_weight_allclose_Q(weight, W):
1581*da0073e9SAndroid Build Coastguard Worker            # Test that weight is equal to the Q part of the QR decomposition of W
1582*da0073e9SAndroid Build Coastguard Worker            # (or of its transpose if the matrix is wide)
1583*da0073e9SAndroid Build Coastguard Worker            wide_matrix = W.size(-2) < W.size(-1)
1584*da0073e9SAndroid Build Coastguard Worker            if wide_matrix:
1585*da0073e9SAndroid Build Coastguard Worker                W = W.mT
1586*da0073e9SAndroid Build Coastguard Worker            Q, R = torch.linalg.qr(W)
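            # QR is unique only up to the signs (unit phases, in the complex
            # case) of R's diagonal; fixing Q so that R's diagonal is positive
            # makes the comparison against weight well defined.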
1587*da0073e9SAndroid Build Coastguard Worker            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
1588*da0073e9SAndroid Build Coastguard Worker            if wide_matrix:
1589*da0073e9SAndroid Build Coastguard Worker                Q = Q.mT
1590*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker        for shape, dtype, use_linear in product(
1593*da0073e9SAndroid Build Coastguard Worker            ((4, 4), (5, 3), (3, 5)),  # square / tall / wide
1594*da0073e9SAndroid Build Coastguard Worker            (torch.float32, torch.complex64),
1595*da0073e9SAndroid Build Coastguard Worker            (True, False),
1596*da0073e9SAndroid Build Coastguard Worker        ):
1597*da0073e9SAndroid Build Coastguard Worker            # Conv2d does not support complex yet, so the Conv2d path is skipped entirely
1598*da0073e9SAndroid Build Coastguard Worker            if not use_linear:
1599*da0073e9SAndroid Build Coastguard Worker                continue
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker            if use_linear:
1602*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(3, shape[0], dtype=dtype)
1603*da0073e9SAndroid Build Coastguard Worker            else:
1604*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker            for parametrization, use_trivialization in product(
1607*da0073e9SAndroid Build Coastguard Worker                ("matrix_exp", "cayley", "householder"), (False, True)
1608*da0073e9SAndroid Build Coastguard Worker            ):
1609*da0073e9SAndroid Build Coastguard Worker                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
1610*da0073e9SAndroid Build Coastguard Worker                # See Note [right_inverse expm cayley]
1611*da0073e9SAndroid Build Coastguard Worker                can_initialize = use_trivialization or parametrization == "householder"
1612*da0073e9SAndroid Build Coastguard Worker
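                # Assigning to m.weight goes through right_inverse, so direct
                # initialization only works with trivialization or householder;
                # the NotImplementedError branches below cover the other cases.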
1613*da0073e9SAndroid Build Coastguard Worker                # We generate them every time to always start with fresh weights
1614*da0073e9SAndroid Build Coastguard Worker                if use_linear:
1615*da0073e9SAndroid Build Coastguard Worker                    m = nn.Linear(*shape, dtype=dtype)
1616*da0073e9SAndroid Build Coastguard Worker                else:
1617*da0073e9SAndroid Build Coastguard Worker                    m = nn.Conv2d(2, 3, shape, dtype=dtype)
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker                # We do not support householder for complex inputs
1620*da0073e9SAndroid Build Coastguard Worker                # See Note [Householder complex]
1621*da0073e9SAndroid Build Coastguard Worker
1622*da0073e9SAndroid Build Coastguard Worker                # When using the swap_tensors path, detach so that the old weight's
1623*da0073e9SAndroid Build Coastguard Worker                # autograd graph is no longer kept alive.
1624*da0073e9SAndroid Build Coastguard Worker                if get_swap_module_params_on_conversion():
1625*da0073e9SAndroid Build Coastguard Worker                    w_init = m.weight.clone().detach()
1626*da0073e9SAndroid Build Coastguard Worker                else:
1627*da0073e9SAndroid Build Coastguard Worker                    w_init = m.weight.clone()
1628*da0073e9SAndroid Build Coastguard Worker                if parametrization == "householder" and m.weight.is_complex():
1629*da0073e9SAndroid Build Coastguard Worker                    msg = "householder parametrization does not support complex tensors"
1630*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(ValueError, msg):
1631*da0073e9SAndroid Build Coastguard Worker                        torch.nn.utils.parametrizations.orthogonal(
1632*da0073e9SAndroid Build Coastguard Worker                            m,
1633*da0073e9SAndroid Build Coastguard Worker                            "weight",
1634*da0073e9SAndroid Build Coastguard Worker                            parametrization,
1635*da0073e9SAndroid Build Coastguard Worker                            use_trivialization=use_trivialization,
1636*da0073e9SAndroid Build Coastguard Worker                        )
1637*da0073e9SAndroid Build Coastguard Worker                    continue
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker                wide_matrix = w_init.size(-2) < w_init.size(-1)
1640*da0073e9SAndroid Build Coastguard Worker                torch.nn.utils.parametrizations.orthogonal(
1641*da0073e9SAndroid Build Coastguard Worker                    m, "weight", parametrization, use_trivialization=use_trivialization
1642*da0073e9SAndroid Build Coastguard Worker                )
1643*da0073e9SAndroid Build Coastguard Worker                # Forwards works as expected
1644*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(w_init.shape, m.weight.shape)
1645*da0073e9SAndroid Build Coastguard Worker                assert_is_orthogonal(m.weight)
1646*da0073e9SAndroid Build Coastguard Worker                if can_initialize:
1647*da0073e9SAndroid Build Coastguard Worker                    assert_weight_allclose_Q(m.weight, w_init)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker                # Initializing with a given orthogonal matrix works
1650*da0073e9SAndroid Build Coastguard Worker                X = torch.randn_like(m.weight)
1651*da0073e9SAndroid Build Coastguard Worker                if wide_matrix:
1652*da0073e9SAndroid Build Coastguard Worker                    X = X.mT
1653*da0073e9SAndroid Build Coastguard Worker                w_new = torch.linalg.qr(X).Q
1654*da0073e9SAndroid Build Coastguard Worker                if wide_matrix:
1655*da0073e9SAndroid Build Coastguard Worker                    w_new = w_new.mT
1656*da0073e9SAndroid Build Coastguard Worker                if can_initialize:
1657*da0073e9SAndroid Build Coastguard Worker                    m.weight = w_new
1658*da0073e9SAndroid Build Coastguard Worker                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
1659*da0073e9SAndroid Build Coastguard Worker                else:
1660*da0073e9SAndroid Build Coastguard Worker                    msg = (
1661*da0073e9SAndroid Build Coastguard Worker                        "assign to the matrix exponential or the Cayley parametrization"
1662*da0073e9SAndroid Build Coastguard Worker                    )
1663*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(NotImplementedError, msg):
1664*da0073e9SAndroid Build Coastguard Worker                        m.weight = w_new
1665*da0073e9SAndroid Build Coastguard Worker
1666*da0073e9SAndroid Build Coastguard Worker                # Initializing with a non-orthogonal matrix sets m.weight to the Q part of its QR decomposition
1667*da0073e9SAndroid Build Coastguard Worker                w_new = torch.randn_like(m.weight)
1668*da0073e9SAndroid Build Coastguard Worker                if can_initialize:
1669*da0073e9SAndroid Build Coastguard Worker                    m.weight = w_new
1670*da0073e9SAndroid Build Coastguard Worker                    assert_weight_allclose_Q(m.weight, w_new)
1671*da0073e9SAndroid Build Coastguard Worker                else:
1672*da0073e9SAndroid Build Coastguard Worker                    msg = (
1673*da0073e9SAndroid Build Coastguard Worker                        "assign to the matrix exponential or the Cayley parametrization"
1674*da0073e9SAndroid Build Coastguard Worker                    )
1675*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(NotImplementedError, msg):
1676*da0073e9SAndroid Build Coastguard Worker                        m.weight = w_new
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker                opt = torch.optim.SGD(m.parameters(), lr=0.1)
1679*da0073e9SAndroid Build Coastguard Worker                for _ in range(2):
1680*da0073e9SAndroid Build Coastguard Worker                    opt.zero_grad()
1681*da0073e9SAndroid Build Coastguard Worker                    m(input).norm().backward()
1682*da0073e9SAndroid Build Coastguard Worker                    grad = m.parametrizations.weight.original.grad
1683*da0073e9SAndroid Build Coastguard Worker                    self.assertIsNotNone(grad)
1684*da0073e9SAndroid Build Coastguard Worker                    # We do not update the upper-triangular part if the matrix is tall, nor the lower-triangular part if it is wide
1685*da0073e9SAndroid Build Coastguard Worker                    if grad.size(-2) >= grad.size(-1):
1686*da0073e9SAndroid Build Coastguard Worker                        zeros_grad = grad.triu(1)
1687*da0073e9SAndroid Build Coastguard Worker                    else:
1688*da0073e9SAndroid Build Coastguard Worker                        zeros_grad = grad.tril(-1)
1689*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
1690*da0073e9SAndroid Build Coastguard Worker                    # The gradient on the diagonal can only be imaginary because a skew-Hermitian
1691*da0073e9SAndroid Build Coastguard Worker                    # matrix has a purely imaginary diagonal
1692*da0073e9SAndroid Build Coastguard Worker                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
1693*da0073e9SAndroid Build Coastguard Worker                    if grad.is_complex():
1694*da0073e9SAndroid Build Coastguard Worker                        diag_grad = diag_grad.real
1695*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
1696*da0073e9SAndroid Build Coastguard Worker                    opt.step()
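                    # The optimizer only updates the unconstrained original
                    # tensor; accessing m.weight maps it back through the
                    # parametrization, so the weight stays orthogonal.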
1697*da0073e9SAndroid Build Coastguard Worker                    assert_is_orthogonal(m.weight)
1698*da0073e9SAndroid Build Coastguard Worker
1699*da0073e9SAndroid Build Coastguard Worker    @skipIfNoLapack
1700*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1701*da0073e9SAndroid Build Coastguard Worker    def test_orthogonal_errors(self):
1702*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(3, 4)
1703*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "has to be one of"):
1704*da0073e9SAndroid Build Coastguard Worker            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
1707*da0073e9SAndroid Build Coastguard Worker            torch.nn.utils.parametrizations.orthogonal(m, "bias")
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker        torch.nn.utils.parametrizations.orthogonal(m, "weight")
1710*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "matrices of shape"):
1711*da0073e9SAndroid Build Coastguard Worker            m.weight = torch.randn(5, 5)
1712*da0073e9SAndroid Build Coastguard Worker        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1715*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_state_dict_compat(self):
1716*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(4, 5)
1717*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.weight_norm(m)
1718*da0073e9SAndroid Build Coastguard Worker        old_dict = m.state_dict()
1719*da0073e9SAndroid Build Coastguard Worker
1720*da0073e9SAndroid Build Coastguard Worker        m2 = nn.Linear(4, 5)
1721*da0073e9SAndroid Build Coastguard Worker        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
1722*da0073e9SAndroid Build Coastguard Worker        m2.load_state_dict(old_dict)
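        # The legacy utils.weight_norm saves weight_g / weight_v, while the
        # parametrization stores original0 / original1; loading works via a
        # state-dict compatibility hook that remaps the old keys.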
1723*da0073e9SAndroid Build Coastguard Worker
1724*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1725*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(input), m2(input))
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1728*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_pickle(self):
1729*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(4, 5)
1730*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.parametrizations.weight_norm(m)
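        # Parametrization swaps in a dynamically generated module class that
        # refuses pickling and points users to state_dict() for serialization.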
1731*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "state_dict"):
1732*da0073e9SAndroid Build Coastguard Worker            pickle.dumps(m)
1733*da0073e9SAndroid Build Coastguard Worker
1734*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1735*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_deepcopy(self):
1736*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(4, 5)
1737*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.parametrizations.weight_norm(m)
1738*da0073e9SAndroid Build Coastguard Worker        m2 = deepcopy(m)
1739*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(input), m2(input))
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker    @swap([True])
1743*da0073e9SAndroid Build Coastguard Worker    def test_wrapper_subclass_parametrization(self):
1744*da0073e9SAndroid Build Coastguard Worker        class Subclassify(nn.Module):
1745*da0073e9SAndroid Build Coastguard Worker            def forward(self, X):
1746*da0073e9SAndroid Build Coastguard Worker                return TwoTensor(X, X)
1747*da0073e9SAndroid Build Coastguard Worker
1748*da0073e9SAndroid Build Coastguard Worker        class UnSubclassify(nn.Module):
1749*da0073e9SAndroid Build Coastguard Worker            def forward(self, X):
1750*da0073e9SAndroid Build Coastguard Worker                return X.a
1751*da0073e9SAndroid Build Coastguard Worker
1752*da0073e9SAndroid Build Coastguard Worker        class IdentityWithRightInverse(nn.Module):
1753*da0073e9SAndroid Build Coastguard Worker            def forward(self, X):
1754*da0073e9SAndroid Build Coastguard Worker                return X
1755*da0073e9SAndroid Build Coastguard Worker
1756*da0073e9SAndroid Build Coastguard Worker            def right_inverse(self, X):
1757*da0073e9SAndroid Build Coastguard Worker                return TwoTensor(X, X)
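        # TwoTensor is a wrapper tensor subclass used for testing; the three
        # parametrizations above convert a plain tensor into a subclass in
        # forward, out of one, or into one via right_inverse.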
1758*da0073e9SAndroid Build Coastguard Worker
1759*da0073e9SAndroid Build Coastguard Worker        def _check_parametrization(
1760*da0073e9SAndroid Build Coastguard Worker            parametrization,
1761*da0073e9SAndroid Build Coastguard Worker            type_before_registration,
1762*da0073e9SAndroid Build Coastguard Worker            type_after_registration,
1763*da0073e9SAndroid Build Coastguard Worker            leave_parametrized=False,
1764*da0073e9SAndroid Build Coastguard Worker            type_after_right_inverse=None,
1765*da0073e9SAndroid Build Coastguard Worker        ):
1766*da0073e9SAndroid Build Coastguard Worker            model = nn.Linear(2, 2)
1767*da0073e9SAndroid Build Coastguard Worker            buf = torch.randn(2, 2)
1768*da0073e9SAndroid Build Coastguard Worker            model.buf = torch.nn.Buffer(buf)
1769*da0073e9SAndroid Build Coastguard Worker            if (
1770*da0073e9SAndroid Build Coastguard Worker                type_before_registration == TwoTensor
1771*da0073e9SAndroid Build Coastguard Worker                and type_after_registration == Tensor
1772*da0073e9SAndroid Build Coastguard Worker            ):
1773*da0073e9SAndroid Build Coastguard Worker                model._apply(lambda t: TwoTensor(t, t))
1774*da0073e9SAndroid Build Coastguard Worker            initial_weight = model.weight.clone().detach()
1775*da0073e9SAndroid Build Coastguard Worker            initial_weight_id = id(model.weight)
1776*da0073e9SAndroid Build Coastguard Worker            initial_buf = model.buf.clone().detach()
1777*da0073e9SAndroid Build Coastguard Worker            initial_buf_id = id(model.buf)
1778*da0073e9SAndroid Build Coastguard Worker            type_original_weight = (
1779*da0073e9SAndroid Build Coastguard Worker                type_before_registration
1780*da0073e9SAndroid Build Coastguard Worker                if type_after_right_inverse is None
1781*da0073e9SAndroid Build Coastguard Worker                else type_after_right_inverse
1782*da0073e9SAndroid Build Coastguard Worker            )
1783*da0073e9SAndroid Build Coastguard Worker            type_original_buf = (
1784*da0073e9SAndroid Build Coastguard Worker                Tensor if type_original_weight is nn.Parameter else type_original_weight
1785*da0073e9SAndroid Build Coastguard Worker            )
1786*da0073e9SAndroid Build Coastguard Worker            type_after_removal_buf = (
1787*da0073e9SAndroid Build Coastguard Worker                type_after_registration if leave_parametrized else type_original_buf
1788*da0073e9SAndroid Build Coastguard Worker            )
1789*da0073e9SAndroid Build Coastguard Worker            if leave_parametrized:
1790*da0073e9SAndroid Build Coastguard Worker                if type_after_registration is Tensor:
1791*da0073e9SAndroid Build Coastguard Worker                    type_after_removal_weight = nn.Parameter
1792*da0073e9SAndroid Build Coastguard Worker                else:
1793*da0073e9SAndroid Build Coastguard Worker                    type_after_removal_weight = type_after_registration
1794*da0073e9SAndroid Build Coastguard Worker            else:
1795*da0073e9SAndroid Build Coastguard Worker                type_after_removal_weight = type_original_weight
1796*da0073e9SAndroid Build Coastguard Worker
1797*da0073e9SAndroid Build Coastguard Worker            parametrize.register_parametrization(model, "weight", parametrization())
1798*da0073e9SAndroid Build Coastguard Worker            parametrize.register_parametrization(model, "buf", parametrization())
1799*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(hasattr(model, "parametrizations"))
1800*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(parametrize.is_parametrized(model))
1801*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(parametrize.is_parametrized(model, "bias"))
1802*da0073e9SAndroid Build Coastguard Worker            # checks for weight
1803*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(parametrize.is_parametrized(model, "weight"))
1804*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1805*da0073e9SAndroid Build Coastguard Worker                isinstance(model.parametrizations.weight.original, nn.Parameter)
1806*da0073e9SAndroid Build Coastguard Worker            )
1807*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1808*da0073e9SAndroid Build Coastguard Worker                type(model.parametrizations.weight.original) is type_original_weight
1809*da0073e9SAndroid Build Coastguard Worker            )
1810*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn("weight", model._parameters)
1811*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(type(model.weight) is type_after_registration)
1812*da0073e9SAndroid Build Coastguard Worker            # checks for buf
1813*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(parametrize.is_parametrized(model, "buf"))
1814*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
1815*da0073e9SAndroid Build Coastguard Worker                isinstance(model.parametrizations.buf.original, nn.Parameter)
1816*da0073e9SAndroid Build Coastguard Worker            )
1817*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1818*da0073e9SAndroid Build Coastguard Worker                type(model.parametrizations.buf.original) is type_original_buf
1819*da0073e9SAndroid Build Coastguard Worker            )
1820*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(type(model.buf) is type_after_registration)
1821*da0073e9SAndroid Build Coastguard Worker            parametrize.remove_parametrizations(
1822*da0073e9SAndroid Build Coastguard Worker                model, "weight", leave_parametrized=leave_parametrized
1823*da0073e9SAndroid Build Coastguard Worker            )
1824*da0073e9SAndroid Build Coastguard Worker            parametrize.remove_parametrizations(
1825*da0073e9SAndroid Build Coastguard Worker                model, "buf", leave_parametrized=leave_parametrized
1826*da0073e9SAndroid Build Coastguard Worker            )
1827*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(hasattr(model, "parametrizations"))
1828*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(model.__class__, nn.Linear)
1829*da0073e9SAndroid Build Coastguard Worker            # checks for weight
1830*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(type(model.weight) is type_after_removal_weight)
1831*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(model.weight, nn.Parameter))
1832*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(id(model.weight), initial_weight_id)
1833*da0073e9SAndroid Build Coastguard Worker            # checks for buf
1834*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(type(model.buf) is type_after_removal_buf)
1835*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(isinstance(model.buf, nn.Parameter))
1836*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(id(model.buf), initial_buf_id)
1837*da0073e9SAndroid Build Coastguard Worker            if not leave_parametrized and type_after_right_inverse is None:
1838*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(model.weight, initial_weight)
1839*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(model.buf, initial_buf)
1840*da0073e9SAndroid Build Coastguard Worker
1841*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
1842*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
1843*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(
1844*da0073e9SAndroid Build Coastguard Worker            IdentityWithRightInverse,
1845*da0073e9SAndroid Build Coastguard Worker            nn.Parameter,
1846*da0073e9SAndroid Build Coastguard Worker            TwoTensor,
1847*da0073e9SAndroid Build Coastguard Worker            type_after_right_inverse=TwoTensor,
1848*da0073e9SAndroid Build Coastguard Worker        )
1849*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(
1850*da0073e9SAndroid Build Coastguard Worker            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
1851*da0073e9SAndroid Build Coastguard Worker        )
1852*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(
1853*da0073e9SAndroid Build Coastguard Worker            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
1854*da0073e9SAndroid Build Coastguard Worker        )
1855*da0073e9SAndroid Build Coastguard Worker        _check_parametrization(
1856*da0073e9SAndroid Build Coastguard Worker            IdentityWithRightInverse,
1857*da0073e9SAndroid Build Coastguard Worker            nn.Parameter,
1858*da0073e9SAndroid Build Coastguard Worker            TwoTensor,
1859*da0073e9SAndroid Build Coastguard Worker            leave_parametrized=True,
1860*da0073e9SAndroid Build Coastguard Worker            type_after_right_inverse=TwoTensor,
1861*da0073e9SAndroid Build Coastguard Worker        )
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker
1864*da0073e9SAndroid Build Coastguard Workerclass TestNNParametrizationDevice(NNTestCase):
1865*da0073e9SAndroid Build Coastguard Worker    @swap([True, False])
1866*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_parametrization(self, device):
1867*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.bfloat16]:
1868*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 4, dtype=dtype, device=device)
1869*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(4, 5, dtype=dtype, device=device)
1870*da0073e9SAndroid Build Coastguard Worker            expected_output = m(input)
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker            # add weight normalization
1873*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.parametrizations.weight_norm(m)
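            # weight_norm reparametrizes weight = g * v / ||v||, storing the
            # magnitude g as original0 (one scalar per slice along `dim`) and
            # the direction v as original1 (same shape as the weight).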
1874*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1875*da0073e9SAndroid Build Coastguard Worker                m.parametrizations.weight.original1.size(), m.weight.size()
1876*da0073e9SAndroid Build Coastguard Worker            )
1877*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
1878*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
1879*da0073e9SAndroid Build Coastguard Worker
1880*da0073e9SAndroid Build Coastguard Worker            # remove weight norm
1881*da0073e9SAndroid Build Coastguard Worker            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1882*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(hasattr(m, "parametrizations"))
1883*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker            # test with dim=1
1886*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
1887*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1888*da0073e9SAndroid Build Coastguard Worker                m.parametrizations.weight.original1.size(), m.weight.size()
1889*da0073e9SAndroid Build Coastguard Worker            )
1890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
1891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
1892*da0073e9SAndroid Build Coastguard Worker
1893*da0073e9SAndroid Build Coastguard Worker            # test with dim=None
1894*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(4, 5, dtype=dtype, device=device)
1895*da0073e9SAndroid Build Coastguard Worker            expected_output = m(input)
1896*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
1897*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Workeronly_for = ("cpu", "cuda")
1901*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
1902*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestNNParametrization)
1903*da0073e9SAndroid Build Coastguard Worker
1904*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1905*da0073e9SAndroid Build Coastguard Worker    run_tests()