# Owner(s): ["module: nn"]
import pickle
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    gradcheck,
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    skipIfNoLapack,
    skipIfTorchDynamo,
    swap,
    TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor


class TestNNParametrization(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    # torch/nn/utils/parametrize
    @skipIfNoLapack
    @swap([True, False])
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer and that removing them restores the initial state.
        It also tests that backpropagating through them works as expected.
        """
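
        # Background sketch (editorial note, not part of the original test):
        # after parametrize.register_parametrization(module, "weight", p),
        # accessing module.weight applies p to the unconstrained tensor stored
        # at module.parametrizations.weight.original, so the optimizer updates
        # `original` while module.weight always satisfies the constraint.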

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()
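
        # Why the Cayley map is orthogonal (editorial sketch): for
        # skew-symmetric X we have (Id - X).T == Id + X, and all factors are
        # polynomials in X so they commute; hence Q = (Id + X)^{-1} (Id - X)
        # satisfies Q.T @ Q == Id. A cheap deterministic spot check
        # (editorial addition; X_demo and Q_demo are illustrative names):
        X_demo = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
        Q_demo = Orthogonal()(X_demo)
        self.assertEqual(Q_demo.T @ Q_demo, torch.eye(2))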

        class Resize(nn.Module):
            def forward(self, X):
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Test unsafe flag
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
            model(torch.ones(8, 8))

        # One parametrization with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)
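
        # Editorial note: unsafe=True skips the eager check that forward()
        # preserves the shape and dtype of the original tensor, which is why
        # the shape-changing Resize registers without error above; problems
        # would only surface when the parametrized attribute is actually used.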

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X
        # Structure tests
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)
        # Remove
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        # Basic tests
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
        # Should not throw

        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
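
        # Editorial note: model.parameters() yields the unconstrained
        # `original` tensors stored under model.parametrizations, so sgd.step()
        # updates those; the constrained model.weight and model.bias are
        # recomputed from them on the next attribute access.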

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
        # Should not throw
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # The module is not parametrized
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
        for _ in range(2):
            parametrize.register_parametrization(model, "weight", Skew())
            parametrize.register_parametrization(model, "weight", Orthogonal())
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            # We didn't change the dtype nor had multiple inputs, so the id should be the same
            self.assertEqual(id(model.weight), initial_weight_id)
            self.assertEqual(id(model.bias), initial_bias_id)
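
        # Editorial note: in the loop above, leave_parametrized=True evaluates
        # the parametrization one last time and writes the constrained value
        # back into the original tensor (via set_ when possible), which is why
        # the parameter ids are preserved.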

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

    @swap([True, False])
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest parametrizations,
        meaning that the original param is parametrized again.
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
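
        # Editorial note: the parametrization machinery is itself built from
        # modules; model.parametrizations.weight is a ParametrizationList that
        # holds the unconstrained tensor as its "original" parameter, so that
        # parameter can be parametrized again, as done below.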

        # Add nested parametrization
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

    @swap([True, False])
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
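
        # Editorial note: replacing the bias Parameter with a Buffer below
        # means the parametrized tensor is registered as a buffer, so only the
        # weight should remain in model.parameters().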

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)
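
        # Editorial sketch: because Orthogonal registers `id` and `B` as
        # buffers, the state_dict of the parametrized model is expected to
        # contain entries such as "0.parametrizations.weight.original" and
        # "0.parametrizations.weight.0.B" (key names assumed from the usual
        # nn.Module nesting), which is what makes the round trip below work.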

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()

        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)
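
        # Editorial sketch: with a parametrization registered, the assignment
        # `module.weight = X` is routed through right_inverse(X), and the
        # result is stored as the unconstrained `original`, so that afterwards
        # forward(original) reproduces X.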

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)

        N = 5
        model = nn.Linear(N, N)
        # Register the skew-symmetric constraint. The result is now skew-symmetric
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way
        # Register now the Cayley map. The result is now orthogonal
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

    @swap([True, False])
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Registering a parametrization on a non-existing parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
                return z, torch.zeros_like(z)
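
        # Editorial sketch: for a multi-output right_inverse such as Sum's,
        # the returned tensors are stored as original0, original1, ..., and
        # module.weight is recomputed as forward(original0, original1).
        # leave_parametrized=False is rejected below because there is no
        # single original tensor to restore.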
        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return None

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()
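
        # Editorial sketch: when right_inverse returns a single tensor, that
        # tensor is stored in place of the original parameter, so it must keep
        # the parameter's dtype; hence the registration below must fail.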

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # Change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype, even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of length one, although weird, is correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)

        parametrize.register_parametrization(module, "weight", SequenceLen1())
        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
        _ = module.weight  # Does not throw
        self.assertTrue(parametrize.is_parametrized(module))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # None of the operations above should have altered the weight
        self.assertFalse(parametrize.is_parametrized(module))
        self.assertEqual(module.weight, weight_init)

    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on a parametrized tensor

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())
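
        # Editorial note: registering on an already-parametrized tensor
        # validates the new module against the currently visible module.weight
        # and appends it to the existing ParametrizationList; on failure the
        # list should be left untouched, which the length-1 checks below verify.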

        # Has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # The following checks are mostly due to bugs in the code of the parametrization

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        # The following checks are mostly due to bugs in the code of the parametrization

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in descending order.
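                # Keeping only the leading singular triple gives the closest
                # rank-1 matrix to Y in the Frobenius norm (Eckart-Young);
                # sqrt(S[0]) is split evenly between the two returned factors.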
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # Simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)
        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
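
        # With several originals there is no single tensor to restore, so the
        # parametrization can only be removed by materializing its current
        # value (leave_parametrized=True).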
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank-1 matrix onto the rank-1 matrices does not change the matrix
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix is now twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model now has three parameters
        self.assertEqual(len(list(model.parameters())), 3)
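        # (original0 and original1 coming from RankOne, plus the bias)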

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        # Test backward. Should not throw
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

        # Same drill as before, removing should work as expected
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # The model now has two parameters
        self.assertEqual(len(list(model.parameters())), 2)

        # Test backward. Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization"""

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Test that the caching system works
        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching"""

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T
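
        # The forward below is the Cayley transform X -> (I + X)^{-1}(I - X),
        # which maps a skew-symmetric X to an orthogonal matrix, e.g.:
        #   Q = Orthogonal()(Skew()(A))
        #   torch.testing.assert_close(Q @ Q.T, torch.eye(A.size(0)))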
        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

            A = to_model.weight
            B = to_model.weight
            self.assertEqual(id(A), id(B))

            # test that the results are distinct objects for each module
            self.assertNotEqual(id(A), id(X))

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Test training mode updated on parametrization registration"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)
        module.train()
        parametrize.register_parametrization(module, "weight", Identity().eval())
        self.assertTrue(module.parametrizations.weight[0].training)
        self.assertTrue(module.parametrizations.weight[1].training)

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )
        parametrize.register_parametrization(model, "weight", Identity())
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )
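
        # (Registering a parametrization replaces the module's class with a
        # dynamically generated subclass, which is what
        # type_before_parametrizations looks through.)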

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate a custom implementation of deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = (
                m1.parametrizations.bias.original
                if parametrize.is_parametrized(m1, "bias")
                else m1.bias
            )
            b2 = (
                m2.parametrizations.bias.original
                if parametrize.is_parametrized(m2, "bias")
                else m2.bias
            )
            # Weights, biases and attributes should be equal but they must be different objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create a deepcopy.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors.
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # testing that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test that transfer works for a parametrization with a right_inverse."""

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
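        # The transfer should respect right_inverse: to_model must end up with
        # the same parametrized value and the same (halved) original parameter.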
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that only the specified parametrizations and parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        parametrize.register_parametrization(model, "bias", AddOne())
        parametrize.register_parametrization(model, "bias", Double())
        parametrize.register_parametrization(model, "bias", MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertTrue("bias" not in to_model.parametrizations)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in descending order.
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )

        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original0,
            to_model.parametrizations.weight.original0,
        )
        self.assertEqual(
            model.parametrizations.weight.original1,
            to_model.parametrizations.weight.original1,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # testing that changes to one set of parametrizations do not affect the other
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # also check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original0,
            to_model.parametrizations.test_param.original0,
        )
        self.assertEqual(
            model.parametrizations.test_param.original1,
            to_model.parametrizations.test_param.original1,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_new_spectral_norm(self):
        with set_default_dtype(torch.double):
            input = torch.randn(3, 5)
            m = nn.Linear(5, 7)
            m = torch.nn.utils.parametrizations.spectral_norm(m)
            spectral_norm_m = m.parametrizations.weight[0]

            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))

            # .parametrizations.weight.original should be trainable
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
            self.assertTrue("original" in m.parametrizations.weight._parameters)
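
            # Spectral norm reparametrizes weight as original / sigma, where
            # sigma is the largest singular value of original, estimated by
            # power iteration on the two buffers checked below.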

            # u should be just a reused buffer
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
            self.assertTrue("_u" in spectral_norm_m._buffers)
            self.assertTrue("_v" in spectral_norm_m._buffers)

            # weight should be a plain attribute, not counted as a buffer or a param
            self.assertIsNotNone(m.weight)
            self.assertFalse("weight" in m._buffers)
            self.assertFalse("weight" in m._parameters)

            # it should also be sharing storage with `parametrizations.weight.original`
            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
            self.assertEqual(
                m.parametrizations.weight.original.stride(), m.weight.stride()
            )

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

            # spectral_norm is the only parametrization
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)

            # We can register spectral_norm multiple times on the same parameter
            # and on multiple parameters in the same module
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")

            # If we remove the parametrization on bias, weight is still parametrized
            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
            self.assertTrue("bias" in m._parameters)
            self.assertTrue(hasattr(m, "parametrizations"))
            self.assertFalse("weight" in m._parameters)

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            # Neither weight nor bias is parametrized
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)
            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))

            # test correctness in training/eval modes and cpu/multi-gpu settings
            for apply_dp in (True, False):
                if apply_dp:
                    if not TEST_MULTIGPU:
                        continue
                    device = torch.device("cuda:0")

                    def maybe_wrap(m):
                        return torch.nn.DataParallel(m, [0, 1])

                else:
                    device = torch.device("cpu")

                    def maybe_wrap(m):
                        return m

                for requires_grad in (True, False):

                    def get_modules():
                        m = nn.Linear(3, 4).to(device)
                        m.weight.requires_grad_(requires_grad)
                        m = torch.nn.utils.parametrizations.spectral_norm(m)
                        wrapped_m = maybe_wrap(m)
                        spectral_norm_m = m.parametrizations.weight[0]
                        return m, wrapped_m, spectral_norm_m

                    input = torch.randn(2, 3, device=device)

                    m, wrapped_m, spectral_norm_m = get_modules()

                    self.assertTrue(hasattr(spectral_norm_m, "_u"))
                    u0 = spectral_norm_m._u.clone()
                    v0 = spectral_norm_m._v.clone()

                    # TEST TRAINING BEHAVIOR

                    # We perform GD first to modify the initial matrix
                    opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)

                    opt.zero_grad()
                    wrapped_m(input).sum().backward()
                    opt.step()

                    out = wrapped_m(input)
                    if requires_grad:
                        # run forward again and assert that u and v are updated
                        self.assertNotEqual(u0, spectral_norm_m._u)
                        self.assertNotEqual(v0, spectral_norm_m._v)
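                        # (Every forward pass in training mode performs one
                        # in-place power-iteration step on _u and _v, which is
                        # why they must have changed.)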

                    # assert that backprop reaches the original weight
                    # can't use gradcheck because the function changes as we
                    # activate through it in training mode
                    if requires_grad:
                        torch.autograd.grad(
                            out.sum(), m.parametrizations.weight.original
                        )

                    # test backward works with multiple forwards
                    # it uses training mode so we need to reset `u` and `v` vectors
                    # to the same value at the beginning for the finite difference test to pass
                    saved_u = spectral_norm_m._u.clone()
                    saved_v = spectral_norm_m._v.clone()

                    def fn(input):
                        spectral_norm_m._u.data.copy_(saved_u)
                        spectral_norm_m._v.data.copy_(saved_v)
                        out0 = wrapped_m(input)
                        out1 = wrapped_m(input)
                        return out0 + out1

                    # Make sure we can compute gradients wrt all the parameters in the case
                    # of double forward
                    fn(input.clone().requires_grad_()).sum().backward()
                    gradcheck(
                        fn, (input.clone().requires_grad_(),), check_batched_grad=False
                    )

                    # test removing
                    # spectral norm module needs to be in eval mode if we'd like to
                    # avoid doing another power iteration
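                    # (With leave_parametrized=True, removal materializes the
                    # weight with one more forward; in eval mode _u/_v stay
                    # frozen, so outputs before and after removal must match.)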
                    m, wrapped_m, _ = get_modules()
                    pre_remove_out = wrapped_m(input)
                    if get_swap_module_params_on_conversion():
                        # When using the swap_tensors path, this is needed so that the autograd
                        # graph is not alive anymore.
                        pre_remove_out_ref = pre_remove_out.detach()
                        del pre_remove_out
                    else:
                        pre_remove_out_ref = pre_remove_out
                    m.eval()
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                    torch.nn.utils.parametrizations.spectral_norm(m)
                    for _ in range(3):
                        pre_remove_out = wrapped_m(input)
                    if get_swap_module_params_on_conversion():
                        # When using the swap_tensors path, this is needed so that the autograd
                        # graph is not alive anymore.
                        pre_remove_out_ref = pre_remove_out.detach()
                        del pre_remove_out
                    else:
                        pre_remove_out_ref = pre_remove_out
                    m.eval()
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                    # TEST EVAL BEHAVIOR
                    m, wrapped_m, spectral_norm_m = get_modules()
                    wrapped_m(input)
                    last_train_out = wrapped_m(input)
                    last_train_u = spectral_norm_m._u.clone()
                    last_train_v = spectral_norm_m._v.clone()
                    wrapped_m.zero_grad()
                    wrapped_m.eval()

                    eval_out0 = wrapped_m(input)
                    # assert eval gives same result as last training iteration
                    self.assertEqual(eval_out0, last_train_out)
                    # assert that more iterations in eval mode don't change things
                    self.assertEqual(eval_out0, wrapped_m(input))
                    self.assertEqual(last_train_u, spectral_norm_m._u)
                    self.assertEqual(last_train_v, spectral_norm_m._v)

                    # FIXME: the code below is flaky when executed with DataParallel
                    # see https://github.com/pytorch/pytorch/issues/13818
                    if apply_dp:
                        continue

                    # test backward works with multiple forwards in mixed training
                    # and eval modes
                    # it uses training mode so we need to reset `u` and `v` vectors
                    # to the same value at the beginning for the finite difference test to pass
                    saved_u = spectral_norm_m._u.clone()
                    saved_v = spectral_norm_m._v.clone()

                    def fn(input):
                        spectral_norm_m._u.data.copy_(saved_u)
                        spectral_norm_m._v.data.copy_(saved_v)
                        wrapped_m.train()
                        out0 = wrapped_m(input)
                        wrapped_m.eval()
                        out1 = wrapped_m(input)
                        wrapped_m.train()
                        out2 = wrapped_m(input)
                        wrapped_m.eval()
                        out3 = wrapped_m(input)
                        return out0 + out1 + out2 + out3

                    gradcheck(fn, (input.clone().requires_grad_(),))

                    # assert that backprop reaches the original weight in eval
                    if requires_grad:

                        def fn(weight):
                            return wrapped_m(input)

                        gradcheck(fn, (m.parametrizations.weight.original,))

    def test_register_parametrization_no_grad(self):
        r"""Test that it is possible to register a parametrization without gradient"""

        class SplitAndCat(nn.Module):
            def right_inverse(self, x):
                # split the tensor into two halves
                return torch.split(x, x.shape[1] // 2)

            def forward(self, x0, x1):
                return torch.cat([x0, x1])

        model = nn.Linear(8, 8)

        model.weight.requires_grad = False
        parametrize.register_parametrization(model, "weight", SplitAndCat())
        # making sure the parametrized and decomposed tensors both have requires_grad == False
        self.assertFalse(model.weight.requires_grad)
        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
        self.assertFalse(model.parametrizations.weight.original1.requires_grad)

    @swap([True, False])
    def test_new_spectral_norm_load_state_dict(self):
        for activate_times in (0, 3):
            inp = torch.randn(2, 3)
            m = nn.Linear(3, 5)
            snm = torch.nn.utils.parametrizations.spectral_norm(m)
            snm.train()

            for _ in range(activate_times):
                snm(inp)

            state_dict = deepcopy(snm.state_dict())
            self.assertEqual(
                {
                    "parametrizations.weight.original",
                    "bias",
                    "parametrizations.weight.0._v",
                    "parametrizations.weight.0._u",
                },
                set(state_dict.keys()),
            )
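
            # The learnable tensor is stored as parametrizations.weight.original,
            # while the spectral norm module itself contributes its _u/_v
            # buffers under the parametrizations.weight.0 prefix.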
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["weight"]  # remove W buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["bias"]
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # normal state_dict

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                out0_eval = snm(inp)
                snm.train()
                out1_train = snm(inp)
                out2_train = snm(inp)
                snm.eval()
                out3_eval = snm(inp)

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            # Test normal loading
            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                self.assertEqual(out0_eval, snm(inp))
                snm.train()
                self.assertEqual(out1_train, snm(inp))
                self.assertEqual(out2_train, snm(inp))
                snm.eval()
                self.assertEqual(out3_eval, snm(inp))

    @swap([True, False])
    def test_new_spectral_norm_dim(self):
        inp = torch.randn(2, 3, 10, 12)
        m = nn.ConvTranspose2d(3, 4, (5, 6))
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # this should not run into incompatible shapes
        x = m(inp)
        # check that u refers to the same dimension
        self.assertEqual(
            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
        )

    @swap([True, False])
    def test_new_spectral_norm_forward(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # naive forward: one power-iteration step on W, then rescale by sigma = u^T W v
        _weight = m.parametrizations.weight.original
        _bias, _v = m.bias, snm._v
        _weight_mat = _weight.view(_weight.size(0), -1)
        _u = torch.mv(_weight_mat, _v)
        _u = F.normalize(_u, dim=0, eps=1e-12)
        _v = torch.mv(_weight_mat.t(), _u)
        _v = F.normalize(_v, dim=0, eps=1e-12)
        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
        out_hat = torch.nn.functional.linear(input, _weight, _bias)
        expect_out = m(input)
        self.assertEqual(expect_out, out_hat)

    @swap([True, False])
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def test_new_spectral_norm_value(self):
        # Check that the spectral norm (= top singular value) is in fact computed
        # correctly, using a simple diagonal matrix as the example.
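        # For a diagonal weight D = diag(d_1, ..., d_n) the singular values are
        # |d_i|, so the spectral norm is max_i |d_i| and the parametrized weight
        # should come out as D / max_i |d_i|.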
        for dtype in (torch.float, torch.cfloat):
            m = nn.Linear(2, 2, dtype=dtype)
            with torch.no_grad():
                # set weight to be diagonal
                x = torch.diagonal(m.weight)
                m.weight = nn.Parameter(torch.diag(x))
                torch.nn.utils.parametrizations.spectral_norm(m)
                # the weight should be rescaled by the spectral norm,
                # i.e., by the largest diagonal element in absolute value
                expected = torch.diag(x / x.abs().max())
                self.assertEqual(m.weight.data, expected)

    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_parametrization(self):
        # Orthogonal implements 6 algorithms (3 parametrizations x 2 choices of use_trivialization)

        def assert_is_orthogonal(X):
            n, k = X.size(-2), X.size(-1)
            if n < k:
                X = X.mT
                n, k = k, n
            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
                *(X.size()[:-2]), k, k
            )
            eps = 10 * n * torch.finfo(X.dtype).eps
            torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)

        def assert_weight_allclose_Q(weight, W):
            # Test that weight is equal to the Q part of the QR decomposition of W
            # (or of its transpose if the matrix is wide)
            wide_matrix = W.size(-2) < W.size(-1)
            if wide_matrix:
                W = W.mT
            Q, R = torch.linalg.qr(W)
            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
            if wide_matrix:
                Q = Q.mT
            torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)

        for shape, dtype, use_linear in product(
            ((4, 4), (5, 3), (3, 5)),  # square / tall / wide
            (torch.float32, torch.complex64),
            (True, False),
        ):
            # Conv2d does not support complex yet
            if not use_linear:
                continue

            if use_linear:
                input = torch.randn(3, shape[0], dtype=dtype)
            else:
                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)

            for parametrization, use_trivialization in product(
                ("matrix_exp", "cayley", "householder"), (False, True)
            ):
                # right_inverse for Cayley and matrix_exp is not implemented for
                # use_trivialization=False. See Note [right_inverse expm cayley]
                can_initialize = use_trivialization or parametrization == "householder"

                # We regenerate the module every time so as to always start with fresh weights
                if use_linear:
                    m = nn.Linear(*shape, dtype=dtype)
                else:
                    m = nn.Conv2d(2, 3, shape, dtype=dtype)

                # We do not support householder for complex inputs
                # See Note [Householder complex]

                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
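                # (clone().detach() yields a tensor with no grad_fn, whereas a plain
                # clone() would keep w_init attached to m.weight's autograd graph.)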
                if get_swap_module_params_on_conversion():
                    w_init = m.weight.clone().detach()
                else:
                    w_init = m.weight.clone()
                if parametrization == "householder" and m.weight.is_complex():
                    msg = "householder parametrization does not support complex tensors"
                    with self.assertRaisesRegex(ValueError, msg):
                        torch.nn.utils.parametrizations.orthogonal(
                            m,
                            "weight",
                            parametrization,
                            use_trivialization=use_trivialization,
                        )
                    continue

                wide_matrix = w_init.size(-2) < w_init.size(-1)
                torch.nn.utils.parametrizations.orthogonal(
                    m, "weight", parametrization, use_trivialization=use_trivialization
                )
                # Forward works as expected
                self.assertEqual(w_init.shape, m.weight.shape)
                assert_is_orthogonal(m.weight)
                if can_initialize:
                    assert_weight_allclose_Q(m.weight, w_init)

                # Initializing with a given orthogonal matrix works
                X = torch.randn_like(m.weight)
                if wide_matrix:
                    X = X.mT
                w_new = torch.linalg.qr(X).Q
                if wide_matrix:
                    w_new = w_new.mT
                if can_initialize:
                    m.weight = w_new
                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                # Initializing with a non-orthogonal matrix makes m.weight be the Q part
                # of the given matrix
                w_new = torch.randn_like(m.weight)
                if can_initialize:
                    m.weight = w_new
                    assert_weight_allclose_Q(m.weight, w_new)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                opt = torch.optim.SGD(m.parameters(), lr=0.1)
                for _ in range(2):
                    opt.zero_grad()
                    m(input).norm().backward()
                    grad = m.parametrizations.weight.original.grad
                    self.assertIsNotNone(grad)
                    # We do not update the upper triangular part of the matrix if tall
                    # (nor the lower triangular part if wide)
                    if grad.size(-2) >= grad.size(-1):
                        zeros_grad = grad.triu(1)
                    else:
                        zeros_grad = grad.tril(-1)
                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
                    # The gradient in the diagonal can only be imaginary because a
                    # skew-Hermitian matrix has an imaginary diagonal
                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
                    if grad.is_complex():
                        diag_grad = diag_grad.real
                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
                    opt.step()
                    assert_is_orthogonal(m.weight)

    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_errors(self):
        m = nn.Linear(3, 4)
        with self.assertRaisesRegex(ValueError, "has to be one of"):
            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")

        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
            torch.nn.utils.parametrizations.orthogonal(m, "bias")

        torch.nn.utils.parametrizations.orthogonal(m, "weight")
        with self.assertRaisesRegex(ValueError, "matrices of shape"):
            m.weight = torch.randn(5, 5)
        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

    @swap([True, False])
    def test_weight_norm_state_dict_compat(self):
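        # A state dict saved from the legacy torch.nn.utils.weight_norm (weight_g /
        # weight_v keys) should load into a module wrapped with the parametrization-based
        # weight_norm; the remapping onto original0 / original1 is assumed to be handled
        # by a load_state_dict pre-hook registered by the parametrization.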
        m = nn.Linear(4, 5)
        m = torch.nn.utils.weight_norm(m)
        old_dict = m.state_dict()

        m2 = nn.Linear(4, 5)
        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
        m2.load_state_dict(old_dict)

        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    @swap([True, False])
    def test_weight_norm_pickle(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            pickle.dumps(m)

    @swap([True, False])
    def test_weight_norm_deepcopy(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        m2 = deepcopy(m)
        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    @swap([True])
    def test_wrapper_subclass_parametrization(self):
        class Subclassify(nn.Module):
            def forward(self, X):
                return TwoTensor(X, X)

        class UnSubclassify(nn.Module):
            def forward(self, X):
                return X.a

        class IdentityWithRightInverse(nn.Module):
            def forward(self, X):
                return X

            def right_inverse(self, X):
                return TwoTensor(X, X)

        def _check_parametrization(
            parametrization,
            type_before_registration,
            type_after_registration,
            leave_parametrized=False,
            type_after_right_inverse=None,
        ):
            model = nn.Linear(2, 2)
            buf = torch.randn(2, 2)
            model.buf = torch.nn.Buffer(buf)
            if (
                type_before_registration == TwoTensor
                and type_after_registration == Tensor
            ):
                model._apply(lambda t: TwoTensor(t, t))
            initial_weight = model.weight.clone().detach()
            initial_weight_id = id(model.weight)
            initial_buf = model.buf.clone().detach()
            initial_buf_id = id(model.buf)
            type_original_weight = (
                type_before_registration
                if type_after_right_inverse is None
                else type_after_right_inverse
            )
            type_original_buf = (
                Tensor if type_original_weight is nn.Parameter else type_original_weight
            )
            type_after_removal_buf = (
                type_after_registration if leave_parametrized else type_original_buf
            )
            if leave_parametrized:
                if type_after_registration is Tensor:
                    type_after_removal_weight = nn.Parameter
                else:
                    type_after_removal_weight = type_after_registration
            else:
                type_after_removal_weight = type_original_weight

            parametrize.register_parametrization(model, "weight", parametrization())
            parametrize.register_parametrization(model, "buf", parametrization())
            self.assertTrue(hasattr(model, "parametrizations"))
            self.assertTrue(parametrize.is_parametrized(model))
            self.assertFalse(parametrize.is_parametrized(model, "bias"))
            # checks for weight
            self.assertTrue(parametrize.is_parametrized(model, "weight"))
            self.assertTrue(
                isinstance(model.parametrizations.weight.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.weight.original) is type_original_weight
            )
            self.assertNotIn("weight", model._parameters)
            self.assertTrue(type(model.weight) is type_after_registration)
            # checks for buf
            self.assertTrue(parametrize.is_parametrized(model, "buf"))
            self.assertFalse(
                isinstance(model.parametrizations.buf.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.buf.original) is type_original_buf
            )
            self.assertTrue(type(model.buf) is type_after_registration)
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=leave_parametrized
            )
            parametrize.remove_parametrizations(
                model, "buf", leave_parametrized=leave_parametrized
            )
            self.assertFalse(hasattr(model, "parametrizations"))
            self.assertEqual(model.__class__, nn.Linear)
            # checks for weight
            self.assertTrue(type(model.weight) is type_after_removal_weight)
            self.assertTrue(isinstance(model.weight, nn.Parameter))
            self.assertEqual(id(model.weight), initial_weight_id)
            # checks for buf
            self.assertTrue(type(model.buf) is type_after_removal_buf)
            self.assertFalse(isinstance(model.buf, nn.Parameter))
            self.assertEqual(id(model.buf), initial_buf_id)
            if not leave_parametrized and type_after_right_inverse is None:
                self.assertEqual(model.weight, initial_weight)
                self.assertEqual(model.buf, initial_buf)

        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            type_after_right_inverse=TwoTensor,
        )
        _check_parametrization(
            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
        )
        _check_parametrization(
            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
        )
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            leave_parametrized=True,
            type_after_right_inverse=TwoTensor,
        )


class TestNNParametrizationDevice(NNTestCase):
    @swap([True, False])
    def test_weight_norm_parametrization(self, device):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype, device=device)
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.parametrizations.weight_norm(m)
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(m(input), expected_output)

            # remove weight norm
            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertEqual(m(input), expected_output)

            # test with dim=1
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(m(input), expected_output)

            # test with dim=None
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)


only_for = ("cpu", "cuda")
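# instantiate_device_type_tests generates a per-device subclass for each backend in
# only_for (e.g. a CPU and a CUDA variant), so each test above runs once per device.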
instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
instantiate_parametrized_tests(TestNNParametrization)

if __name__ == "__main__":
    run_tests()
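# Invoking this file directly runs the whole suite through run_tests(); a subset can
# typically be selected with unittest-style filtering, e.g. `-k spectral_norm`
# (assuming the standard PyTorch test harness semantics).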