1# Owner(s): ["module: nn"] 2import pickle 3from copy import deepcopy 4from itertools import product 5 6import torch 7import torch.nn as nn 8import torch.nn.functional as F 9import torch.nn.init as init 10import torch.nn.utils.parametrize as parametrize 11from torch import Tensor 12from torch.__future__ import get_swap_module_params_on_conversion 13from torch.nn import Buffer, Parameter 14from torch.testing._internal.common_cuda import TEST_MULTIGPU 15from torch.testing._internal.common_device_type import instantiate_device_type_tests 16from torch.testing._internal.common_nn import NNTestCase 17from torch.testing._internal.common_utils import ( 18 gradcheck, 19 instantiate_parametrized_tests, 20 run_tests, 21 set_default_dtype, 22 skipIfNoLapack, 23 skipIfTorchDynamo, 24 swap, 25 TemporaryFileName, 26) 27from torch.testing._internal.two_tensor import TwoTensor 28 29 30class TestNNParametrization(NNTestCase): 31 _do_cuda_memory_leak_check = True 32 _do_cuda_non_default_stream = True 33 34 # FIXME: Rewrite this test using functions not depending on LAPACK 35 # and remove the `@skipIfNoLapack` (see #70995) 36 # torch/nn/utils/parametrize 37 @skipIfNoLapack 38 @swap([True, False]) 39 def test_register_and_remove_parametrization(self): 40 r"""Test that it is possible to add a few parametrizations 41 on a parameter or a buffer and that removing them restores the initial state 42 It also tests that backpropagating through them works as expected 43 """ 44 45 # Define a couple matrix parametrizations 46 class Skew(nn.Module): 47 def forward(self, X): 48 X = X.tril(-1) 49 return X - X.T 50 51 class Orthogonal(nn.Module): 52 def forward(self, X): 53 # Cayley map 54 # If X is skew-symmetric it returns an orthogonal matrix 55 Id = torch.eye(X.size(0), device=X.device) 56 # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous 57 # and autograd raises a performance warning. 58 # This happens when we remove the parametrization with leave_parametrized=True, 59 # which does a set_ with a non-contiguous tensor while the gradient is contiguous 60 return torch.linalg.solve(Id + X, Id - X).contiguous() 61 62 class Resize(nn.Module): 63 def forward(self, X): 64 return X[[0]] 65 66 class NoResize(nn.Module): 67 def forward(self, X): 68 return X 69 70 # Define a couple vector parametrizations 71 class FirstZero(nn.Module): 72 def forward(self, x): 73 return torch.cat([x.new_zeros(1), x[1:]]) 74 75 class LastZero(nn.Module): 76 def forward(self, x): 77 return torch.cat([x[:-1], x.new_zeros(1)]) 78 79 model = nn.Linear(8, 8) 80 initial_weight_id = id(model.weight) 81 initial_bias_id = id(model.bias) 82 initial_model = deepcopy(model) 83 84 # Test unsafe flag 85 with self.assertRaisesRegex( 86 ValueError, 87 "Registering a parametrization may not change the shape of the tensor", 88 ): 89 parametrize.register_parametrization( 90 model, "weight", Resize() 91 ) # default unsafe = False 92 model(torch.ones(8, 8)) 93 94 # One parametrization with unsafe=True 95 parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) 96 self.assertTrue(hasattr(model, "parametrizations")) 97 self.assertTrue(parametrize.is_parametrized(model)) 98 self.assertTrue(parametrize.is_parametrized(model, "weight")) 99 self.assertFalse(parametrize.is_parametrized(model, "bias")) 100 self.assertNotIn("weight", model._parameters) 101 self.assertTrue(model.weight.shape[0] == 1) 102 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 103 self.assertFalse(hasattr(model, "parametrizations")) 104 self.assertEqual(model.weight, initial_model.weight) 105 self.assertEqual(id(model.weight), initial_weight_id) 106 self.assertEqual(model.__class__, nn.Linear) 107 108 # Two parametrizations with unsafe=True 109 parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) 110 parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False) 111 self.assertTrue(hasattr(model, "parametrizations")) 112 self.assertTrue(parametrize.is_parametrized(model)) 113 self.assertTrue(parametrize.is_parametrized(model, "weight")) 114 self.assertFalse(parametrize.is_parametrized(model, "bias")) 115 self.assertNotIn("weight", model._parameters) 116 self.assertTrue(model.weight.shape[0] == 1) 117 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 118 self.assertFalse(hasattr(model, "parametrizations")) 119 self.assertEqual(model.weight, initial_model.weight) 120 self.assertEqual(id(model.weight), initial_weight_id) 121 self.assertEqual(model.__class__, nn.Linear) 122 123 # Test unsafe flag doesn't change expected behavior 124 parametrize.register_parametrization(model, "weight", Skew(), unsafe=True) 125 self.assertTrue(hasattr(model, "parametrizations")) 126 self.assertTrue(parametrize.is_parametrized(model)) 127 self.assertTrue(parametrize.is_parametrized(model, "weight")) 128 self.assertFalse(parametrize.is_parametrized(model, "bias")) 129 self.assertNotIn("weight", model._parameters) 130 # Result should be skew-symmetric 131 A = model.weight 132 self.assertEqual(A, -A.T) 133 if get_swap_module_params_on_conversion(): 134 # When using the swap_tensors path, this is needed so that the autograd 135 # graph is not alive anymore. 136 del A 137 # Remove and check consistency 138 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 139 self.assertFalse(hasattr(model, "parametrizations")) 140 self.assertEqual(model.weight, initial_model.weight) 141 self.assertEqual(id(model.weight), initial_weight_id) 142 self.assertEqual(model.__class__, nn.Linear) 143 144 # Test one parametrization 145 parametrize.register_parametrization(model, "weight", Skew()) 146 self.assertTrue(hasattr(model, "parametrizations")) 147 self.assertTrue(parametrize.is_parametrized(model)) 148 self.assertTrue(parametrize.is_parametrized(model, "weight")) 149 self.assertFalse(parametrize.is_parametrized(model, "bias")) 150 self.assertNotIn("weight", model._parameters) 151 # Result should be skew-symmetric 152 A = model.weight 153 self.assertEqual(A, -A.T) 154 if get_swap_module_params_on_conversion(): 155 # When using the swap_tensors path, this is needed so that the autograd 156 # graph is not alive anymore. 157 del A 158 # Remove and check consistency 159 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 160 self.assertFalse(hasattr(model, "parametrizations")) 161 self.assertEqual(model.weight, initial_model.weight) 162 self.assertEqual(id(model.weight), initial_weight_id) 163 self.assertEqual(model.__class__, nn.Linear) 164 165 # Test two parametrizations at the same time and removing them 166 parametrize.register_parametrization(model, "weight", Skew()) 167 parametrize.register_parametrization(model, "weight", Orthogonal()) 168 # Result should be orthogonal 169 X = model.weight 170 Id = torch.eye(X.size(0), device=X.device) 171 self.assertEqual(X.T @ X, Id) 172 if get_swap_module_params_on_conversion(): 173 # When using the swap_tensors path, this is needed so that the autograd 174 # graph is not alive anymore. 175 del X 176 # Structure tests 177 self.assertTrue(hasattr(model, "parametrizations")) 178 self.assertTrue(parametrize.is_parametrized(model)) 179 self.assertTrue(parametrize.is_parametrized(model, "weight")) 180 self.assertFalse(parametrize.is_parametrized(model, "bias")) 181 self.assertIn("weight", model.parametrizations) 182 self.assertNotIn("weight", model._parameters) 183 # Remove 184 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 185 self.assertEqual(model.weight, initial_model.weight) 186 self.assertEqual(id(model.weight), initial_weight_id) 187 self.assertFalse(hasattr(model, "parametrizations")) 188 self.assertEqual(model.__class__, nn.Linear) 189 190 # Add everything 191 parametrize.register_parametrization(model, "weight", Skew()) 192 parametrize.register_parametrization(model, "weight", Orthogonal()) 193 parametrize.register_parametrization(model, "bias", FirstZero()) 194 parametrize.register_parametrization(model, "bias", LastZero()) 195 196 # Basic tests 197 self.assertTrue(parametrize.is_parametrized(model)) 198 self.assertTrue(parametrize.is_parametrized(model, "weight")) 199 self.assertTrue(parametrize.is_parametrized(model, "bias")) 200 self.assertEqual(model.bias[0].item(), 0.0) 201 self.assertEqual(model.bias[-1].item(), 0.0) 202 self.assertEqual( 203 len(list(model.parameters())), 2 204 ) # Nothing weird has happpened 205 # Should not throw 206 207 sgd = torch.optim.SGD(model.parameters(), lr=0.01) 208 209 weight_copy = model.weight.clone() 210 bias_copy = model.bias.clone() 211 sgd.zero_grad() 212 (model.weight.T @ model.bias).sum().backward() 213 sgd.step() 214 self.assertNotEqual(model.weight, weight_copy) 215 self.assertNotEqual(model.bias, bias_copy) 216 217 # Remove first parametrization. 218 # Check that the model is still parametrized and so is the second parameter 219 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 220 self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized 221 self.assertFalse( 222 parametrize.is_parametrized(model, "weight") 223 ) # Parametrization removed 224 self.assertTrue( 225 parametrize.is_parametrized(model, "bias") 226 ) # Still parametrized 227 self.assertEqual(model.bias[0].item(), 0.0) # Still parametrized 228 self.assertEqual(model.bias[-1].item(), 0.0) # Still parametrized 229 self.assertNotEqual(model.weight, initial_model.weight) # Has been updated 230 self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id 231 self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened 232 # Should not throw 233 weight_copy = model.weight.clone() 234 bias_copy = model.bias.clone() 235 sgd.zero_grad() 236 (model.weight.T @ model.bias).sum().backward() 237 sgd.step() 238 self.assertNotEqual(model.weight, weight_copy) 239 self.assertNotEqual(model.bias, bias_copy) 240 241 # Remove the second parametrization. 242 # Check that the module is not parametrized 243 parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) 244 self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized 245 self.assertNotEqual(model.bias, initial_model.bias) # Has been updated 246 self.assertNotEqual(model.bias[0].item(), 0.0) # Not parametrized 247 self.assertNotEqual(model.bias[-1].item(), 0.0) # Not parametrized 248 self.assertEqual(id(model.bias), initial_bias_id) # Keeps the same id 249 self.assertFalse( 250 hasattr(model, "parametrizations") 251 ) # Not parametrized the module 252 self.assertEqual(model.__class__, nn.Linear) # Resores the previous class 253 self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happeed 254 255 # Should not throw things are updated 256 weight_copy = model.weight.clone() 257 bias_copy = model.bias.clone() 258 sgd.zero_grad() 259 (model.weight.T @ model.bias).sum().backward() 260 sgd.step() 261 self.assertNotEqual(model.weight, weight_copy) 262 self.assertNotEqual(model.bias, bias_copy) 263 if get_swap_module_params_on_conversion(): 264 # When using the swap_tensors path, this is needed so that the autograd 265 # graph is not alive anymore. 266 del weight_copy, bias_copy 267 268 # Test leave_parametrized=True 269 for _ in range(2): 270 parametrize.register_parametrization(model, "weight", Skew()) 271 parametrize.register_parametrization(model, "weight", Orthogonal()) 272 parametrize.remove_parametrizations( 273 model, "weight", leave_parametrized=True 274 ) 275 # We didn't change the dtype nor had multiple inputs, so the id should be the same 276 self.assertEqual(id(model.weight), initial_weight_id) 277 self.assertEqual(id(model.bias), initial_bias_id) 278 279 # Should not throw. Things are updated 280 weight_copy = model.weight.clone() 281 bias_copy = model.bias.clone() 282 sgd.zero_grad() 283 (model.weight.T @ model.bias).sum().backward() 284 sgd.step() 285 self.assertNotEqual(model.weight, weight_copy) 286 self.assertNotEqual(model.bias, bias_copy) 287 if get_swap_module_params_on_conversion(): 288 # When using the swap_tensors path, this is needed so that the autograd 289 # graph is not alive anymore. 290 del weight_copy, bias_copy 291 292 @swap([True, False]) 293 def test_register_and_remove_nested_parametrization(self): 294 r"""Test that it is possible to nest the parametrizations 295 meaning that the original param is parametrized again 296 """ 297 298 class Skew(nn.Module): 299 def forward(self, X): 300 X = X.tril(-1) 301 return X - X.T 302 303 model = nn.Linear(8, 8) 304 # Add top level parametrization 305 parametrize.register_parametrization(model, "weight", Skew()) 306 self.assertTrue(hasattr(model, "parametrizations")) 307 self.assertTrue(parametrize.is_parametrized(model)) 308 self.assertTrue(parametrize.is_parametrized(model, "weight")) 309 self.assertFalse(parametrize.is_parametrized(model, "bias")) 310 self.assertNotIn("weight", model._parameters) 311 # Result should be skew-symmetric 312 A = model.weight 313 self.assertEqual(A, -A.T) 314 if get_swap_module_params_on_conversion(): 315 # When using the swap_tensors path, this is needed so that the autograd 316 # graph is not alive anymore. 317 del A 318 319 # Add nested parametrization 320 param_mod = model.parametrizations.weight 321 self.assertFalse(hasattr(param_mod, "parametrizations")) 322 self.assertFalse(parametrize.is_parametrized(param_mod)) 323 self.assertFalse(parametrize.is_parametrized(param_mod, "original")) 324 325 parametrize.register_parametrization(param_mod, "original", Skew()) 326 self.assertTrue(hasattr(param_mod, "parametrizations")) 327 self.assertTrue(parametrize.is_parametrized(param_mod)) 328 self.assertTrue(parametrize.is_parametrized(param_mod, "original")) 329 self.assertNotIn("original", param_mod._parameters) 330 # Result should be skew-symmetric 331 A = param_mod.original 332 self.assertEqual(A, -A.T) 333 334 # Remove nested param and check consistency 335 parametrize.remove_parametrizations( 336 param_mod, "original", leave_parametrized=False 337 ) 338 self.assertFalse(hasattr(param_mod, "parametrizations")) 339 self.assertEqual(param_mod.__class__, parametrize.ParametrizationList) 340 341 # Remove top level and check consistency 342 parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) 343 self.assertFalse(hasattr(model, "parametrizations")) 344 self.assertEqual(model.__class__, nn.Linear) 345 346 @swap([True, False]) 347 def test_register_and_remove_buffer_parametrization(self): 348 r"""Test that it is possible to add and remove parametrizations on buffers""" 349 350 # Define a couple vector parametrizations 351 class FirstZero(nn.Module): 352 def forward(self, x): 353 return torch.cat([x.new_zeros(1), x[1:]]) 354 355 class LastZero(nn.Module): 356 def forward(self, x): 357 return torch.cat([x[:-1], x.new_zeros(1)]) 358 359 model = nn.Linear(8, 8) 360 361 # Instantiate parametrizations on buffers. It should work as expected 362 delattr(model, "bias") 363 model.bias = Buffer(torch.ones(8)) 364 parametrize.register_parametrization(model, "bias", FirstZero()) 365 parametrize.register_parametrization(model, "bias", LastZero()) 366 self.assertTrue(parametrize.is_parametrized(model)) 367 self.assertTrue(parametrize.is_parametrized(model, "bias")) 368 self.assertEqual(model.bias[0].item(), 0.0) 369 self.assertEqual(model.bias[-1].item(), 0.0) 370 self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) 371 self.assertEqual(len(list(model.parameters())), 1) 372 373 # Remove parametrizations on buffers. It should work as expected 374 parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) 375 self.assertFalse(parametrize.is_parametrized(model)) 376 self.assertFalse(parametrize.is_parametrized(model, "bias")) 377 self.assertEqual(model.bias[0].item(), 0.0) 378 self.assertEqual(model.bias[-1].item(), 0.0) 379 self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) 380 self.assertEqual(len(list(model.parameters())), 1) 381 382 # FIXME: Rewrite this test using functions not depending on LAPACK 383 # and remove the `@skipIfNoLapack` (see #70995) 384 @skipIfNoLapack 385 @swap([True, False]) 386 def test_serialization_parametrization(self): 387 r"""Test that it is possible to serialize a parametrized model via state_dict""" 388 389 # A stateful parametrization 390 class Orthogonal(nn.Module): 391 def __init__(self, n): 392 super().__init__() 393 self.id = Buffer(torch.eye(n)) 394 self.B = Buffer(torch.empty(n, n)) 395 init.orthogonal_(self.B) 396 397 def forward(self, X): 398 A = X.triu(1) 399 A = A - A.T 400 return self.B @ torch.linalg.solve(self.id + A, self.id - A) 401 402 def get_model(): 403 model = torch.nn.Sequential( 404 torch.nn.Linear(5, 5), 405 torch.nn.ReLU(), 406 torch.nn.Linear(5, 1), 407 ) 408 409 parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) 410 return model 411 412 model = get_model() 413 414 prev_weight = model[0].weight 415 prev_B = model[0].parametrizations.weight[0].B 416 417 new_model = get_model() 418 with TemporaryFileName() as fname: 419 torch.save(model.state_dict(), fname) 420 new_model.load_state_dict(torch.load(fname)) 421 422 # Integrity tests 423 self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) 424 self.assertEqual(prev_weight, new_model[0].weight) 425 self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) 426 427 # Trying to save the whole parametrized model raises 428 with self.assertRaisesRegex(RuntimeError, "state_dict"): 429 with TemporaryFileName() as fname: 430 torch.save(model, fname) 431 432 # FIXME: Rewrite this test using functions not depending on LAPACK 433 # and remove the `@skipIfNoLapack` (see #70995) 434 @skipIfNoLapack 435 @swap([True, False]) 436 def test_initialization_parametrization(self): 437 r"""Test that it is possible to initialize a parametrization when it 438 implements a `right_inverse` method 439 """ 440 441 class Skew(nn.Module): 442 def forward(self, X): 443 A = X.triu(1) 444 return A - A.T 445 446 def is_skew(self, A): 447 return torch.allclose(A, -A.T, atol=1e-6) 448 449 def right_inverse(self, X): 450 if not self.is_skew(X): 451 raise ValueError("The matrix is not skew-symmetric.") 452 return X.triu(1) 453 454 # Implements a Cayley map where right_inverse is not quite the inverse of forward 455 class Orthogonal(nn.Module): 456 def __init__(self, n): 457 super().__init__() 458 self.B = Buffer(torch.eye(n)) 459 460 def forward(self, X): 461 Id = torch.eye(X.size(0)) 462 return self.B @ torch.linalg.solve(Id + X, Id - X) 463 464 def is_orthogonal(self, X): 465 Id = torch.eye(X.size(0)) 466 return torch.allclose(X.T @ X, Id, atol=1e-4) 467 468 def right_inverse(self, X): 469 if not self.is_orthogonal(X): 470 raise ValueError("The input is not orthogonal.") 471 # cayley(0) == Id, so B @ cayley(0) == B 472 self.B = X 473 return torch.zeros_like(X) 474 475 N = 5 476 model = nn.Linear(N, N) 477 # Register the skew-symmetric constraint. The result is now skew-symmetric 478 skew = Skew() 479 # Make the weight skew-symmetric before registering the parametrization 480 with torch.no_grad(): 481 model.weight.set_(skew(model.weight)) 482 parametrize.register_parametrization(model, "weight", skew) 483 X = torch.rand(N, N) 484 # X is not skew-symmetric, so it throws an error 485 with self.assertRaises(ValueError): 486 model.weight = X 487 # Make X skew-symmetric 488 X = X - X.T 489 model.weight = X 490 self.assertEqual(model.parametrizations.weight.original, X.triu(1)) 491 self.assertEqual(model.weight, X) 492 493 # Having several parametrizations registered should work in the same way 494 parametrize.register_parametrization(model, "weight", Orthogonal(N)) 495 # Register now the Cayley map. The result is now orthogonal 496 X = torch.rand(N, N) 497 # X is not orthogonal, so it throws an error 498 with self.assertRaises(ValueError): 499 model.weight = X 500 init.orthogonal_(X) 501 model.weight = X 502 self.assertEqual(model.weight, X) 503 self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) 504 505 @swap([True, False]) 506 def test_errors_unparametrized_tensor_parametrization(self): 507 # Test errors when registering a parametrization on an unparametrized tensor 508 module = nn.Linear(3, 4) 509 weight_init = module.weight.clone() 510 511 class Identity(nn.Module): 512 def forward(self, x): 513 return x 514 515 # Register a parametrization on a non-existing parameter throws 516 with self.assertRaisesRegex(ValueError, "does not have a parameter"): 517 parametrize.register_parametrization(module, "foo", Identity()) 518 self.assertFalse(parametrize.is_parametrized(module)) 519 520 # Removing parametrizations from an unparametrized tensor throws 521 with self.assertRaisesRegex(ValueError, "does not have a parametrization"): 522 parametrize.remove_parametrizations(module, "bias") 523 self.assertFalse(parametrize.is_parametrized(module)) 524 525 # A correct parametrization with several outputs 526 class Sum(nn.Module): 527 def forward(self, x, y): 528 return x + y 529 530 def right_inverse(self, z): 531 return z, torch.zeros_like(z) 532 533 parametrize.register_parametrization(module, "weight", Sum()) 534 # Cannot remove a parametrization with several outputs with `leave_parametrized=False` 535 with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): 536 parametrize.remove_parametrizations( 537 module, "weight", leave_parametrized=False 538 ) 539 parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) 540 541 # A parametrization with an incorrect number of outputs 542 class WrongNumberParams(nn.Module): 543 def forward(self, x, y, z): 544 return x + y + z 545 546 def right_inverse(self, w): 547 return w, torch.zeros_like(w) 548 549 # Makes param(*param.right_inverse(X)) fail 550 with self.assertRaisesRegex(TypeError, "positional argument"): 551 parametrize.register_parametrization(module, "weight", WrongNumberParams()) 552 self.assertFalse(parametrize.is_parametrized(module)) 553 554 # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor] 555 class WrongRightInverse(Identity): 556 def right_inverse(self, z): 557 return None 558 559 # right_inverse should return a Tensor or a Sequence[Tensor] 560 with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"): 561 parametrize.register_parametrization(module, "weight", WrongRightInverse()) 562 self.assertFalse(parametrize.is_parametrized(module)) 563 564 # If it's a sequence, it must to be a sequence of tensors 565 class WrongRightInverseSequence(nn.Module): 566 def forward(self, x, y): 567 return x 568 569 def right_inverse(self, z): 570 return None, z 571 572 with self.assertRaisesRegex(ValueError, "of the sequence with type"): 573 parametrize.register_parametrization( 574 module, "weight", WrongRightInverseSequence() 575 ) 576 self.assertFalse(parametrize.is_parametrized(module)) 577 578 # A parametrization from one tensor to one tensor that changes the dtype 579 class ChangeDtypeInverse(nn.Module): 580 def forward(self, x): 581 return x.float() 582 583 def right_inverse(self, w): 584 return w.bool() 585 586 # For parametrizations that return one tensor, right_inverse may not change the dtype 587 with self.assertRaisesRegex( 588 ValueError, "outputs one tensor, it may not change the dtype" 589 ): 590 parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) 591 self.assertFalse(parametrize.is_parametrized(module)) 592 593 # Doesn't return a tensor 594 class NotTensor(nn.Module): 595 def forward(self, x): 596 return 2 597 598 # Forward must return a tensor 599 with self.assertRaisesRegex(ValueError, "must return a tensor"): 600 parametrize.register_parametrization(module, "weight", NotTensor()) 601 self.assertFalse(parametrize.is_parametrized(module)) 602 603 # A parametrization from one tensor to one tensor that changes the dtype 604 class ChangeDtype(nn.Module): 605 def forward(self, x): 606 return x.bool() 607 608 # forward should not change the initial dtype 609 with self.assertRaisesRegex(ValueError, "may not change the dtype"): 610 parametrize.register_parametrization(module, "weight", ChangeDtype()) 611 self.assertFalse(parametrize.is_parametrized(module)) 612 613 # Change shape 614 class ChangeShape(nn.Module): 615 def forward(self, x): 616 return x[:-1] 617 618 # forward should not change the original shape 619 with self.assertRaisesRegex(ValueError, "may not change the shape"): 620 parametrize.register_parametrization(module, "weight", ChangeShape()) 621 self.assertFalse(parametrize.is_parametrized(module)) 622 623 # Many to one that changes dtype 624 class ChangeDtypeMulti(nn.Module): 625 def forward(self, x, y): 626 return (x + y).bool() 627 628 def right_inverse(self, w): 629 return w, w + 1 630 631 # forward should not change the original shape even for parametrizations with many inputs 632 with self.assertRaisesRegex(ValueError, "may not change the dtype"): 633 parametrize.register_parametrization(module, "weight", ChangeDtypeMulti()) 634 self.assertFalse(parametrize.is_parametrized(module)) 635 636 # Returning a sequence of size one, although weird, it's correct 637 class SequenceLen1(nn.Module): 638 def forward(self, x): 639 return x 640 641 def right_inverse(self, w): 642 return (w,) 643 644 parametrize.register_parametrization(module, "weight", SequenceLen1()) 645 self.assertTrue(hasattr(module.parametrizations.weight, "original0")) 646 self.assertFalse(hasattr(module.parametrizations.weight, "original1")) 647 _ = module.weight # Does not throw 648 self.assertTrue(parametrize.is_parametrized(module)) 649 parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) 650 651 # None of the operations above should have altered the weight 652 self.assertFalse(parametrize.is_parametrized(module)) 653 self.assertEqual(module.weight, weight_init) 654 655 @swap([True, False]) 656 def test_errors_parametrized_tensor_parametrization(self): 657 # Test errors when registering a parametrization on a parametrized tensor 658 659 class Identity(nn.Module): 660 def forward(self, x): 661 return x 662 663 module = nn.Linear(3, 4) 664 parametrize.register_parametrization(module, "weight", Identity()) 665 666 # Has to return a tensor 667 class WrongReturn(nn.Module): 668 def forward(self, x): 669 return x, x 670 671 with self.assertRaisesRegex(ValueError, "must return a tensor"): 672 parametrize.register_parametrization(module, "weight", WrongReturn()) 673 self.assertTrue(parametrize.is_parametrized(module)) 674 self.assertEqual(len(module.parametrizations.weight), 1) 675 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 676 677 # Cannot change dtype 678 class ChangeDtype(nn.Module): 679 def forward(self, x): 680 return x.bool() 681 682 with self.assertRaisesRegex(ValueError, "may not change the dtype"): 683 parametrize.register_parametrization(module, "weight", ChangeDtype()) 684 self.assertTrue(parametrize.is_parametrized(module)) 685 self.assertEqual(len(module.parametrizations.weight), 1) 686 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 687 688 # Cannot change shape 689 class ChangeShape(nn.Module): 690 def forward(self, x): 691 return x[:-1] 692 693 with self.assertRaisesRegex(ValueError, "may not change the shape"): 694 parametrize.register_parametrization(module, "weight", ChangeShape()) 695 self.assertTrue(parametrize.is_parametrized(module)) 696 self.assertEqual(len(module.parametrizations.weight), 1) 697 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 698 699 # The following checks are mostly due to bugs in the code of the parametrization 700 701 # right_inverse has to return a tensor 702 class WrongReturnInverse(Identity): 703 def right_inverse(self, x): 704 return x, x 705 706 with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"): 707 parametrize.register_parametrization(module, "weight", WrongReturnInverse()) 708 self.assertTrue(parametrize.is_parametrized(module)) 709 self.assertEqual(len(module.parametrizations.weight), 1) 710 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 711 712 # Cannot change dtype 713 class ChangeDtypeInverse(Identity): 714 def right_inverse(self, x): 715 return x.bool() 716 717 with self.assertRaisesRegex(ValueError, "must have the same dtype"): 718 parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) 719 self.assertTrue(parametrize.is_parametrized(module)) 720 self.assertEqual(len(module.parametrizations.weight), 1) 721 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 722 723 # Cannot change shape 724 class ChangeShapeInverse(Identity): 725 def right_inverse(self, x): 726 return x[:-1] 727 728 with self.assertRaisesRegex(ValueError, "must have the same shape"): 729 parametrize.register_parametrization(module, "weight", ChangeShapeInverse()) 730 self.assertTrue(parametrize.is_parametrized(module)) 731 self.assertEqual(len(module.parametrizations.weight), 1) 732 self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) 733 734 # FIXME: Rewrite this test using functions not depending on LAPACK 735 # and remove the `@skipIfNoLapack` (see #70995) 736 @skipIfNoLapack 737 @swap([True, False]) 738 def test_multiple_inputs_parametrization(self): 739 # A parametrization with several outputs 740 class RankOne(nn.Module): 741 def forward(self, x, y): 742 # Form a rank-1 matrix from a pair of vectors 743 return x.unsqueeze(-1) @ y.unsqueeze(-2) 744 745 def right_inverse(self, Y): 746 # We project the given matrix onto the rank 1 matrices 747 U, S, Vh = torch.linalg.svd(Y, full_matrices=False) 748 # S is ordered in a decreasing way. 749 s0_sqrt = S[0].sqrt().unsqueeze(-1) 750 return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt 751 752 # Simple parametrisation 753 class Double(nn.Module): 754 def forward(self, x): 755 return 2.0 * x 756 757 def right_inverse(self, w): 758 return 0.5 * w 759 760 model = nn.Linear(3, 3) 761 # Test one parametrization 762 parametrize.register_parametrization(model, "weight", RankOne()) 763 self.assertTrue(hasattr(model, "parametrizations")) 764 self.assertTrue(parametrize.is_parametrized(model)) 765 self.assertTrue(parametrize.is_parametrized(model, "weight")) 766 self.assertTrue(hasattr(model.parametrizations.weight, "original0")) 767 self.assertIn("original0", model.parametrizations.weight._parameters) 768 self.assertTrue(hasattr(model.parametrizations.weight, "original1")) 769 self.assertIn("original1", model.parametrizations.weight._parameters) 770 self.assertFalse(parametrize.is_parametrized(model, "bias")) 771 self.assertNotIn("weight", model._parameters) 772 # Result should be rank 1 773 self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) 774 775 with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): 776 # Cannot remove a parametrization with multiple inputs and not leave it parametrized 777 parametrize.remove_parametrizations( 778 model, "weight", leave_parametrized=False 779 ) 780 # Remove parametrization and check consistency 781 parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) 782 self.assertFalse(hasattr(model, "parametrizations")) 783 self.assertEqual(model.__class__, nn.Linear) 784 self.assertFalse(parametrize.is_parametrized(model)) 785 self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) 786 self.assertIn("weight", model._parameters) 787 788 # Registering parametrizations with one input on top of one with multiple inputs should work 789 init_weight = model.weight.clone() 790 parametrize.register_parametrization(model, "weight", RankOne()) 791 # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix 792 self.assertEqual(init_weight, model.weight) 793 parametrize.register_parametrization(model, "weight", Double()) 794 # The matrix now is twice the initial matrix 795 self.assertEqual(2.0 * init_weight, model.weight) 796 # Multiplying by a scalar does not change the rank 797 self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) 798 799 # The model has now three parameters 800 self.assertEqual(len(list(model.parameters())), 3) 801 802 sgd = torch.optim.SGD(model.parameters(), lr=0.1) 803 804 # Test backward. Should not throw 805 for _ in range(2): 806 sgd.zero_grad() 807 loss = (model.weight.T @ model.bias).sum() 808 loss.backward() 809 sgd.step() 810 811 # Same drill as before, removing should work as expected 812 with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): 813 # Cannot remove a parametrization with multiple inputs and not leave it parametrized 814 parametrize.remove_parametrizations( 815 model, "weight", leave_parametrized=False 816 ) 817 # Remove parametrization and check consistency 818 parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) 819 self.assertFalse(hasattr(model, "parametrizations")) 820 self.assertEqual(model.__class__, nn.Linear) 821 self.assertFalse(parametrize.is_parametrized(model)) 822 self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) 823 self.assertIn("weight", model._parameters) 824 825 # The model has now two parameters 826 self.assertEqual(len(list(model.parameters())), 2) 827 828 # Test backward. Should not throw 829 sgd = torch.optim.SGD(model.parameters(), lr=0.1) 830 for _ in range(2): 831 sgd.zero_grad() 832 loss = (model.weight.T @ model.bias).sum() 833 loss.backward() 834 sgd.step() 835 836 # FIXME: Rewrite this test using functions not depending on LAPACK 837 # and remove the `@skipIfNoLapack` (see #70995) 838 @skipIfNoLapack 839 @swap([True, False]) 840 def test_caching_parametrization(self): 841 r"""Test the caching system of a parametrization""" 842 843 # Define a couple matrix parametrizations 844 class Skew(nn.Module): 845 def forward(self, X): 846 X = X.tril(-1) 847 return X - X.T 848 849 class Orthogonal(nn.Module): 850 def forward(self, X): 851 Id = torch.eye(X.size(0), device=X.device) 852 return torch.linalg.solve(Id + X, Id - X) 853 854 model = nn.Linear(5, 5) 855 parametrize.register_parametrization(model, "weight", Skew()) 856 parametrize.register_parametrization(model, "weight", Orthogonal()) 857 858 # Test that the caching system works 859 with parametrize.cached(): 860 X = model.weight 861 Y = model.weight 862 self.assertEqual(id(X), id(Y)) 863 864 # FIXME: Rewrite this test using functions not depending on LAPACK 865 # and remove the `@skipIfNoLapack` (see #70995) 866 @skipIfNoLapack 867 @swap([True, False]) 868 def test_caching_parametrization_with_transfer_parametrizations_and_params(self): 869 r"""Test that transferring parametrizations doesn't cause issues with caching""" 870 871 class Skew(nn.Module): 872 def forward(self, X): 873 X = X.tril(-1) 874 return X - X.T 875 876 class Orthogonal(nn.Module): 877 def forward(self, X): 878 Id = torch.eye(X.size(0), device=X.device) 879 return torch.linalg.solve(Id + X, Id - X) 880 881 model = nn.Linear(5, 5) 882 parametrize.register_parametrization(model, "weight", Skew()) 883 parametrize.register_parametrization(model, "weight", Orthogonal()) 884 885 to_model = nn.Linear(5, 5) 886 parametrize.transfer_parametrizations_and_params(model, to_model) 887 888 with parametrize.cached(): 889 X = model.weight 890 Y = model.weight 891 self.assertEqual(id(X), id(Y)) 892 893 A = to_model.weight 894 B = to_model.weight 895 self.assertEqual(id(A), id(B)) 896 897 # test that the results are distinct objects for each module 898 self.assertNotEqual(id(A), id(X)) 899 900 @swap([True, False]) 901 def test_parametrization_same_training_mode(self): 902 r"""Test training mode updated on parametrization registration""" 903 904 class Identity(nn.Module): 905 def forward(self, X): 906 return X 907 908 module = nn.Linear(4, 4) 909 module.eval() 910 parametrize.register_parametrization(module, "weight", Identity()) 911 self.assertFalse(module.parametrizations.weight[0].training) 912 module.train() 913 parametrize.register_parametrization(module, "weight", Identity().eval()) 914 self.assertTrue(module.parametrizations.weight[0].training) 915 self.assertTrue(module.parametrizations.weight[1].training) 916 917 @swap([True, False]) 918 def test_type_before_parametrizations(self): 919 r"""Test that type_before_parametrizations always retrieves original type""" 920 921 class Identity(nn.Module): 922 def forward(self, X): 923 return X 924 925 model = nn.Linear(5, 5) 926 original_type = type(model) 927 self.assertTrue( 928 parametrize.type_before_parametrizations(model) == original_type 929 ) 930 parametrize.register_parametrization(model, "weight", Identity()) 931 self.assertTrue( 932 parametrize.type_before_parametrizations(model) == original_type 933 ) 934 935 @swap([True, False]) 936 def test_deepcopy_after_parametrization(self): 937 r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" 938 939 class AddOne(nn.Module): 940 def forward(self, x): 941 return x + 1.0 942 943 class ModelWithoutDeepcopy(nn.Module): 944 def __init__(self) -> None: 945 super().__init__() 946 self.weight = nn.Parameter( 947 torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True 948 ) 949 self.bias = nn.Parameter( 950 torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True 951 ) 952 self.attr = [1.0, 2.0, 3.0, 4.0] 953 954 class ActualModel(ModelWithoutDeepcopy): 955 # Emulate custom implementation of the deepcopying. 956 def __deepcopy__(self, memo): 957 result = self.__new__(self.__class__) 958 memo[id(self)] = result 959 result.__dict__ = deepcopy(self.__dict__, memo) 960 return result 961 962 def check_deepcopy(m1: nn.Module, m2: nn.Module): 963 w1 = m1.parametrizations.weight.original 964 w2 = m2.parametrizations.weight.original 965 b1 = ( 966 m1.parametrizations.bias.original 967 if parametrize.is_parametrized(m1, "bias") 968 else m1.bias 969 ) 970 b2 = ( 971 m2.parametrizations.bias.original 972 if parametrize.is_parametrized(m2, "bias") 973 else m2.bias 974 ) 975 # Weights, biases and attributes should be equal but they must be different objects. 976 self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys()) 977 self.assertIsNot(m1, m2) 978 self.assertEqual(w1, w2) 979 self.assertIsNot(w1, w2) 980 self.assertEqual(b1, b2) 981 self.assertIsNot(b1, b2) 982 self.assertEqual(m1.attr, m2.attr) 983 self.assertIsNot(m1.attr, m2.attr) 984 985 for model in (ModelWithoutDeepcopy(), ActualModel()): 986 # General check that we are able to create deepcopy. 987 parametrize.register_parametrization(model, "weight", AddOne()) 988 check_deepcopy(model, deepcopy(model)) 989 # Check that this works on models with several parametrized tensors. 990 parametrize.register_parametrization(model, "bias", AddOne()) 991 check_deepcopy(model, deepcopy(model)) 992 # Check that this works on models where tensors have more than one parametrization. 993 parametrize.register_parametrization(model, "weight", AddOne()) 994 check_deepcopy(model, deepcopy(model)) 995 996 @swap([True, False]) 997 def test_transfer_parametrizations_and_params(self): 998 r"""Test that all parametrizations and their associated parameters are transferred.""" 999 1000 class AddOne(nn.Module): 1001 def forward(self, x): 1002 return x + 1.0 1003 1004 class Double(nn.Module): 1005 def forward(self, x): 1006 return 2.0 * x 1007 1008 def right_inverse(self, x): 1009 return 0.5 * x 1010 1011 class MinusOne(nn.Module): 1012 def forward(self, x): 1013 return x - 1.0 1014 1015 model = nn.Linear(5, 5) 1016 parametrize.register_parametrization(model, "weight", AddOne()) 1017 parametrize.register_parametrization(model, "weight", Double()) 1018 parametrize.register_parametrization(model, "weight", MinusOne()) 1019 hold_weight = model.weight 1020 1021 to_model = torch.ao.nn.qat.Linear( 1022 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() 1023 ) 1024 parametrize.transfer_parametrizations_and_params(model, to_model) 1025 1026 # checks that final and original value are correct and the to_model is parametrized 1027 self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) 1028 self.assertEqual(model.weight, to_model.weight) 1029 self.assertEqual( 1030 model.parametrizations.weight.original, 1031 to_model.parametrizations.weight.original, 1032 ) 1033 1034 # check that the transfer didn't affect the original value 1035 self.assertEqual(hold_weight, model.weight) 1036 if get_swap_module_params_on_conversion(): 1037 # When using the swap_tensors path, this is needed so that the autograd 1038 # graph is not alive anymore. 1039 del hold_weight 1040 1041 # testing that changes to one set of parametrizations do not affect the other 1042 parametrize.remove_parametrizations(to_model, "weight") 1043 self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) 1044 self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight")) 1045 1046 # also test that parameters that don't exist in to_model get transferred 1047 model.test_param = Parameter(torch.randn(5, 5)) 1048 1049 self.assertTrue(not hasattr(to_model, "test_param")) 1050 parametrize.register_parametrization(model, "test_param", Double()) 1051 hold_test_param = model.test_param 1052 parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") 1053 1054 # check that previously missing params got transferred correctly 1055 self.assertEqual(model.test_param, to_model.test_param) 1056 self.assertEqual( 1057 model.parametrizations.test_param.original, 1058 to_model.parametrizations.test_param.original, 1059 ) 1060 1061 # check that the new transfer didn't change the value for the from_module 1062 self.assertEqual(hold_test_param, model.test_param) 1063 1064 @swap([True, False]) 1065 def test_transfer_parametrizations_and_params_right_inverse(self): 1066 r"""Test that all parametrizations and their associated parameters are transferred.""" 1067 1068 class Double(nn.Module): 1069 def forward(self, x): 1070 return 2.0 * x 1071 1072 def right_inverse(self, x): 1073 return 0.5 * x 1074 1075 model = nn.Linear(5, 5) 1076 parametrize.register_parametrization(model, "weight", Double()) 1077 hold_weight = model.weight 1078 1079 to_model = torch.ao.nn.qat.Linear( 1080 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() 1081 ) 1082 parametrize.transfer_parametrizations_and_params(model, to_model) 1083 1084 # check that transfer occurs successfully 1085 self.assertEqual(model.weight, to_model.weight) 1086 self.assertEqual( 1087 model.parametrizations.weight.original, 1088 to_model.parametrizations.weight.original, 1089 ) 1090 1091 # check that transfer doesn't affect the from_model weight 1092 self.assertEqual(hold_weight, model.weight) 1093 1094 @swap([True, False]) 1095 def test_transfer_parametrizations_and_params_single_param(self): 1096 r"""Test that all parametrizations and their associated parameters are transferred.""" 1097 1098 class AddOne(nn.Module): 1099 def forward(self, x): 1100 return x + 1.0 1101 1102 class Double(nn.Module): 1103 def forward(self, x): 1104 return 2.0 * x 1105 1106 class MinusOne(nn.Module): 1107 def forward(self, x): 1108 return x - 1.0 1109 1110 model = nn.Linear(5, 5, bias=True) 1111 parametrize.register_parametrization(model, "weight", AddOne()) 1112 parametrize.register_parametrization(model, "weight", Double()) 1113 parametrize.register_parametrization(model, "weight", MinusOne()) 1114 parametrize.register_parametrization(model, "bias", AddOne()) 1115 parametrize.register_parametrization(model, "bias", Double()) 1116 parametrize.register_parametrization(model, "bias", MinusOne()) 1117 1118 to_model = torch.ao.nn.qat.Linear( 1119 5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig() 1120 ) 1121 parametrize.transfer_parametrizations_and_params(model, to_model, "weight") 1122 1123 # check that weight and only weight was transferred 1124 self.assertEqual(model.weight, to_model.weight) 1125 self.assertEqual( 1126 model.parametrizations.weight.original, 1127 to_model.parametrizations.weight.original, 1128 ) 1129 self.assertTrue("bias" not in to_model.parametrizations) 1130 1131 # FIXME: Rewrite this test using functions not depending on LAPACK 1132 # and remove the `@skipIfNoLapack` (see #70995) 1133 @skipIfNoLapack 1134 @swap([True, False]) 1135 def test_transfer_parametrizations_and_params_many_to_one(self): 1136 # A parametrization with several outputs 1137 class RankOne(nn.Module): 1138 def forward(self, x, y): 1139 # Form a rank-1 matrix from a pair of vectors 1140 return x.unsqueeze(-1) @ y.unsqueeze(-2) 1141 1142 def right_inverse(self, Y): 1143 # We project the given matrix onto the rank 1 matrices 1144 U, S, Vh = torch.linalg.svd(Y, full_matrices=False) 1145 # S is ordered in a decreasing way. 1146 s0_sqrt = S[0].sqrt().unsqueeze(-1) 1147 return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt 1148 1149 class Double(nn.Module): 1150 def forward(self, x): 1151 return 2.0 * x 1152 1153 model = nn.Linear(3, 3) 1154 parametrize.register_parametrization(model, "weight", RankOne()) 1155 parametrize.register_parametrization(model, "weight", Double()) 1156 hold_weight = model.weight 1157 1158 to_model = torch.ao.nn.qat.Linear( 1159 3, 3, qconfig=torch.ao.quantization.get_default_qconfig() 1160 ) 1161 1162 parametrize.transfer_parametrizations_and_params(model, to_model) 1163 1164 # checks that final and original value are correct and the to_model is parametrized 1165 self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) 1166 self.assertEqual(model.weight, to_model.weight) 1167 self.assertEqual( 1168 model.parametrizations.weight.original0, 1169 to_model.parametrizations.weight.original0, 1170 ) 1171 self.assertEqual( 1172 model.parametrizations.weight.original1, 1173 to_model.parametrizations.weight.original1, 1174 ) 1175 1176 # check that the transfer didn't affect the original value 1177 self.assertEqual(hold_weight, model.weight) 1178 1179 # testing that changes to one set of parametrizations do not affect the other 1180 model.test_param = Parameter(torch.randn(3, 3)) 1181 1182 self.assertTrue(not hasattr(to_model, "test_param")) 1183 parametrize.register_parametrization(model, "test_param", RankOne()) 1184 hold_test_param = model.test_param 1185 parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") 1186 1187 # also check that previously missing params got transferred correctly 1188 self.assertEqual(model.test_param, to_model.test_param) 1189 self.assertEqual( 1190 model.parametrizations.test_param.original0, 1191 to_model.parametrizations.test_param.original0, 1192 ) 1193 self.assertEqual( 1194 model.parametrizations.test_param.original1, 1195 to_model.parametrizations.test_param.original1, 1196 ) 1197 1198 # check that the new transfer didn't change the value for the from_module 1199 self.assertEqual(hold_test_param, model.test_param) 1200 1201 @swap([True, False]) 1202 def test_new_spectral_norm(self): 1203 with set_default_dtype(torch.double): 1204 input = torch.randn(3, 5) 1205 m = nn.Linear(5, 7) 1206 m = torch.nn.utils.parametrizations.spectral_norm(m) 1207 spectral_norm_m = m.parametrizations.weight[0] 1208 1209 self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)])) 1210 1211 # .parametrizations.weight.original should be trainable 1212 self.assertTrue(hasattr(m.parametrizations.weight, "original")) 1213 self.assertTrue("original" in m.parametrizations.weight._parameters) 1214 1215 # u should be just a reused buffer 1216 self.assertTrue(hasattr(spectral_norm_m, "_u")) 1217 self.assertTrue("_u" in spectral_norm_m._buffers) 1218 self.assertTrue("_v" in spectral_norm_m._buffers) 1219 1220 # weight should be a plain attribute, not counted as a buffer or a param 1221 self.assertIsNotNone(m.weight) 1222 self.assertFalse("weight" in m._buffers) 1223 self.assertFalse("weight" in m._parameters) 1224 1225 # it should also be sharing storage as `weight_orig` 1226 # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage()) 1227 self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size()) 1228 self.assertEqual( 1229 m.parametrizations.weight.original.stride(), m.weight.stride() 1230 ) 1231 1232 m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1233 1234 # spectral_norm is the only parametrization 1235 self.assertFalse(hasattr(m, "parametrizations")) 1236 self.assertTrue("weight" in m._parameters) 1237 1238 # We can register spectral_norm multiple times on the same parameter 1239 # and on multiple parameters in the same module 1240 m = torch.nn.utils.parametrizations.spectral_norm(m, "weight") 1241 m = torch.nn.utils.parametrizations.spectral_norm(m, "weight") 1242 m = torch.nn.utils.parametrizations.spectral_norm(m, "bias") 1243 1244 # If we remove the parametrization on bias, weight is still parametrized 1245 # Removing a parametrization runs forward in eval mode if leave_parametrized=True 1246 m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias") 1247 self.assertTrue("bias" in m._parameters) 1248 self.assertTrue(hasattr(m, "parametrizations")) 1249 self.assertFalse("weight" in m._parameters) 1250 1251 m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1252 # Neither weight and bias are parametrized 1253 self.assertFalse(hasattr(m, "parametrizations")) 1254 self.assertTrue("weight" in m._parameters) 1255 self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m)) 1256 1257 # test correctness in training/eval modes and cpu/multi-gpu settings 1258 for apply_dp in (True, False): 1259 if apply_dp: 1260 if not TEST_MULTIGPU: 1261 continue 1262 device = torch.device("cuda:0") 1263 1264 def maybe_wrap(m): 1265 return torch.nn.DataParallel(m, [0, 1]) 1266 1267 else: 1268 device = torch.device("cpu") 1269 1270 def maybe_wrap(m): 1271 return m 1272 1273 for requires_grad in (True, False): 1274 1275 def get_modules(): 1276 m = nn.Linear(3, 4).to(device) 1277 m.weight.requires_grad_(requires_grad) 1278 m = torch.nn.utils.parametrizations.spectral_norm(m) 1279 wrapped_m = maybe_wrap(m) 1280 spectral_norm_m = m.parametrizations.weight[0] 1281 return m, wrapped_m, spectral_norm_m 1282 1283 input = torch.randn(2, 3, device=device) 1284 1285 m, wrapped_m, spectral_norm_m = get_modules() 1286 1287 self.assertTrue(hasattr(spectral_norm_m, "_u")) 1288 u0 = spectral_norm_m._u.clone() 1289 v0 = spectral_norm_m._v.clone() 1290 1291 # TEST TRAINING BEHAVIOR 1292 1293 # We perform GD first to modify the initial matrix 1294 opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1) 1295 1296 opt.zero_grad() 1297 wrapped_m(input).sum().backward() 1298 opt.step() 1299 1300 out = wrapped_m(input) 1301 if requires_grad: 1302 # run forward again and assert that u and v are updated 1303 self.assertNotEqual(u0, spectral_norm_m._u) 1304 self.assertNotEqual(v0, spectral_norm_m._v) 1305 1306 # assert that backprop reaches original weight 1307 # can't use gradcheck because the function changes as we 1308 # activate through it in training mode 1309 if requires_grad: 1310 torch.autograd.grad( 1311 out.sum(), m.parametrizations.weight.original 1312 ) 1313 1314 # test backward works with multiple forwards 1315 # it uses training mode so we need to reset `u` and `v` vectors 1316 # to same value at beginning for finite difference test to pass 1317 saved_u = spectral_norm_m._u.clone() 1318 saved_v = spectral_norm_m._v.clone() 1319 1320 def fn(input): 1321 spectral_norm_m._u.data.copy_(saved_u) 1322 spectral_norm_m._v.data.copy_(saved_v) 1323 out0 = wrapped_m(input) 1324 out1 = wrapped_m(input) 1325 return out0 + out1 1326 1327 # Make sure we can compute gradients wrt to all the parameters in the case 1328 # of double forward 1329 fn(input.clone().requires_grad_()).sum().backward() 1330 gradcheck( 1331 fn, (input.clone().requires_grad_(),), check_batched_grad=False 1332 ) 1333 1334 # test removing 1335 # spectral norm module needs to be in eval mode if we'd like to 1336 # avoid doing another power iteration 1337 m, wrapped_m, _ = get_modules() 1338 pre_remove_out = wrapped_m(input) 1339 if get_swap_module_params_on_conversion(): 1340 # When using the swap_tensors path, this is needed so that the autograd 1341 # graph is not alive anymore. 1342 pre_remove_out_ref = pre_remove_out.detach() 1343 del pre_remove_out 1344 else: 1345 pre_remove_out_ref = pre_remove_out 1346 m.eval() 1347 m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1348 self.assertEqual(wrapped_m(input), pre_remove_out_ref) 1349 1350 torch.nn.utils.parametrizations.spectral_norm(m) 1351 for _ in range(3): 1352 pre_remove_out = wrapped_m(input) 1353 if get_swap_module_params_on_conversion(): 1354 # When using the swap_tensors path, this is needed so that the autograd 1355 # graph is not alive anymore. 1356 pre_remove_out_ref = pre_remove_out.detach() 1357 del pre_remove_out 1358 else: 1359 pre_remove_out_ref = pre_remove_out 1360 m.eval() 1361 m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1362 self.assertEqual(wrapped_m(input), pre_remove_out_ref) 1363 1364 # TEST EVAL BEHAVIOR 1365 m, wrapped_m, spectral_norm_m = get_modules() 1366 wrapped_m(input) 1367 last_train_out = wrapped_m(input) 1368 last_train_u = spectral_norm_m._u.clone() 1369 last_train_v = spectral_norm_m._v.clone() 1370 wrapped_m.zero_grad() 1371 wrapped_m.eval() 1372 1373 eval_out0 = wrapped_m(input) 1374 # assert eval gives same result as last training iteration 1375 self.assertEqual(eval_out0, last_train_out) 1376 # assert doing more iteartion in eval don't change things 1377 self.assertEqual(eval_out0, wrapped_m(input)) 1378 self.assertEqual(last_train_u, spectral_norm_m._u) 1379 self.assertEqual(last_train_v, spectral_norm_m._v) 1380 1381 # FIXME: the code below is flaky when executed with DataParallel 1382 # see https://github.com/pytorch/pytorch/issues/13818 1383 if apply_dp: 1384 continue 1385 1386 # test backward works with multiple forwards in mixed training 1387 # and eval modes 1388 # it uses training mode so we need to reset `u` and `v` vectors 1389 # to same value at beginning for finite difference test to pass 1390 saved_u = spectral_norm_m._u.clone() 1391 saved_v = spectral_norm_m._v.clone() 1392 1393 def fn(input): 1394 spectral_norm_m._u.data.copy_(saved_u) 1395 spectral_norm_m._v.data.copy_(saved_v) 1396 wrapped_m.train() 1397 out0 = wrapped_m(input) 1398 wrapped_m.eval() 1399 out1 = wrapped_m(input) 1400 wrapped_m.train() 1401 out2 = wrapped_m(input) 1402 wrapped_m.eval() 1403 out3 = wrapped_m(input) 1404 return out0 + out1 + out2 + out3 1405 1406 gradcheck(fn, (input.clone().requires_grad_(),)) 1407 1408 # assert that backprop reaches weight_orig in eval 1409 if requires_grad: 1410 1411 def fn(weight): 1412 return wrapped_m(input) 1413 1414 gradcheck(fn, (m.parametrizations.weight.original,)) 1415 1416 def test_register_parametrization_no_grad(self): 1417 r"""Test that it is possible to register a parametrization without gradient""" 1418 1419 class SplitAndCat(nn.Module): 1420 def right_inverse(self, x): 1421 # split the tensor in two halfs 1422 return torch.split(x, x.shape[1] // 2) 1423 1424 def forward(self, x0, x1): 1425 return torch.cat([x0, x1]) 1426 1427 model = nn.Linear(8, 8) 1428 1429 model.weight.requires_grad = False 1430 parametrize.register_parametrization(model, "weight", SplitAndCat()) 1431 # making sure the parameterized and decomposed Tensors both have requires_grad == False 1432 self.assertFalse(model.weight.requires_grad) 1433 self.assertFalse(model.parametrizations.weight.original0.requires_grad) 1434 self.assertFalse(model.parametrizations.weight.original1.requires_grad) 1435 1436 @swap([True, False]) 1437 def test_new_spectral_norm_load_state_dict(self): 1438 for activate_times in (0, 3): 1439 inp = torch.randn(2, 3) 1440 m = nn.Linear(3, 5) 1441 snm = torch.nn.utils.parametrizations.spectral_norm(m) 1442 snm.train() 1443 1444 for _ in range(activate_times): 1445 snm(inp) 1446 1447 state_dict = deepcopy(snm.state_dict()) 1448 self.assertEqual( 1449 { 1450 "parametrizations.weight.original", 1451 "bias", 1452 "parametrizations.weight.0._v", 1453 "parametrizations.weight.0._u", 1454 }, 1455 set(state_dict.keys()), 1456 ) 1457 1458 # test that non-strict loading works 1459 non_strict_state_dict = deepcopy(state_dict) 1460 non_strict_state_dict["nonsense"] = "nonsense" 1461 with self.assertRaisesRegex( 1462 RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"' 1463 ): 1464 snm.load_state_dict(non_strict_state_dict, strict=True) 1465 snm.load_state_dict(non_strict_state_dict, strict=False) 1466 del non_strict_state_dict["parametrizations.weight.original"] 1467 snm.load_state_dict(non_strict_state_dict, strict=False) 1468 del non_strict_state_dict["parametrizations.weight.0._u"] 1469 snm.load_state_dict(non_strict_state_dict, strict=False) 1470 del non_strict_state_dict["parametrizations.weight.0._v"] 1471 snm.load_state_dict(non_strict_state_dict, strict=False) 1472 non_strict_state_dict[ 1473 "weight" 1474 ] = snm.weight.detach().clone() # set W as a buffer 1475 snm.load_state_dict(non_strict_state_dict, strict=False) 1476 del non_strict_state_dict._metadata[ 1477 "parametrizations.weight.0" 1478 ] # remove metadata info 1479 snm.load_state_dict(non_strict_state_dict, strict=False) 1480 del non_strict_state_dict["weight"] # remove W buffer 1481 snm.load_state_dict(non_strict_state_dict, strict=False) 1482 del non_strict_state_dict["bias"] 1483 snm.load_state_dict(non_strict_state_dict, strict=False) 1484 1485 # normal state_dict 1486 1487 # test that re-wrapping does not matter 1488 m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight") 1489 snm = torch.nn.utils.parametrizations.spectral_norm(m) 1490 1491 snm.load_state_dict(state_dict) 1492 with torch.no_grad(): 1493 snm.eval() 1494 out0_eval = snm(inp) 1495 snm.train() 1496 out1_train = snm(inp) 1497 out2_train = snm(inp) 1498 snm.eval() 1499 out3_eval = snm(inp) 1500 1501 # test that re-wrapping does not matter 1502 m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight") 1503 snm = torch.nn.utils.parametrizations.spectral_norm(m) 1504 1505 # Test normal loading 1506 snm.load_state_dict(state_dict) 1507 with torch.no_grad(): 1508 snm.eval() 1509 self.assertEqual(out0_eval, snm(inp)) 1510 snm.train() 1511 self.assertEqual(out1_train, snm(inp)) 1512 self.assertEqual(out2_train, snm(inp)) 1513 snm.eval() 1514 self.assertEqual(out3_eval, snm(inp)) 1515 1516 @swap([True, False]) 1517 def test_new_spectral_norm_dim(self): 1518 inp = torch.randn(2, 3, 10, 12) 1519 m = nn.ConvTranspose2d(3, 4, (5, 6)) 1520 m = torch.nn.utils.parametrizations.spectral_norm(m) 1521 snm = m.parametrizations.weight[0] 1522 # this should not run into incompatible shapes 1523 x = m(inp) 1524 # check that u refers to the same dimension 1525 self.assertEqual( 1526 snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape 1527 ) 1528 1529 @swap([True, False]) 1530 def test_new_spectral_norm_forward(self): 1531 input = torch.randn(3, 5) 1532 m = nn.Linear(5, 7) 1533 m = torch.nn.utils.parametrizations.spectral_norm(m) 1534 snm = m.parametrizations.weight[0] 1535 # naive forward 1536 _weight = m.parametrizations.weight.original 1537 _bias, _v = m.bias, snm._v 1538 _weight_mat = _weight.view(_weight.size(0), -1) 1539 _u = torch.mv(_weight_mat, _v) 1540 _u = F.normalize(_u, dim=0, eps=1e-12) 1541 _v = torch.mv(_weight_mat.t(), _u) 1542 _v = F.normalize(_v, dim=0, eps=1e-12) 1543 _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v)) 1544 out_hat = torch.nn.functional.linear(input, _weight, _bias) 1545 expect_out = m(input) 1546 self.assertEqual(expect_out, out_hat) 1547 1548 @swap([True, False]) 1549 @skipIfTorchDynamo("Test does not work with TorchDynamo") 1550 def test_new_spectral_norm_value(self): 1551 # a test that the spectral norm (= top singular value) 1552 # is in fact properly calculated, using example of a simple diagonal matrix. 1553 for dtype in (torch.float, torch.cfloat): 1554 m = nn.Linear(2, 2, dtype=dtype) 1555 with torch.no_grad(): 1556 # set weight to be diagonal 1557 x = torch.diagonal(m.weight) 1558 m.weight = nn.Parameter(torch.diag(x)) 1559 torch.nn.utils.parametrizations.spectral_norm(m) 1560 # weights should be rescaled by spectral norm, (i.e., largest diagonal element in norm) 1561 expected = torch.diag(x / x.abs().max()) 1562 self.assertEqual(m.weight.data, expected) 1563 1564 @skipIfNoLapack 1565 @swap([True, False]) 1566 def test_orthogonal_parametrization(self): 1567 # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization) 1568 1569 def assert_is_orthogonal(X): 1570 n, k = X.size(-2), X.size(-1) 1571 if n < k: 1572 X = X.mT 1573 n, k = k, n 1574 Id = torch.eye(k, dtype=X.dtype, device=X.device).expand( 1575 *(X.size()[:-2]), k, k 1576 ) 1577 eps = 10 * n * torch.finfo(X.dtype).eps 1578 torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0) 1579 1580 def assert_weight_allclose_Q(weight, W): 1581 # Test that weight is equal to the Q part of the QR decomposition of W 1582 # (or of its transpose if the matrix is wide) 1583 wide_matrix = W.size(-2) < W.size(-1) 1584 if wide_matrix: 1585 W = W.mT 1586 Q, R = torch.linalg.qr(W) 1587 Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) 1588 if wide_matrix: 1589 Q = Q.mT 1590 torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0) 1591 1592 for shape, dtype, use_linear in product( 1593 ((4, 4), (5, 3), (3, 5)), # square/ tall / wide 1594 (torch.float32, torch.complex64), 1595 (True, False), 1596 ): 1597 # Conv2d does not support complex yet 1598 if not use_linear: 1599 continue 1600 1601 if use_linear: 1602 input = torch.randn(3, shape[0], dtype=dtype) 1603 else: 1604 input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype) 1605 1606 for parametrization, use_trivialization in product( 1607 ("matrix_exp", "cayley", "householder"), (False, True) 1608 ): 1609 # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False 1610 # See Note [right_inverse expm cayley] 1611 can_initialize = use_trivialization or parametrization == "householder" 1612 1613 # We generate them every time to always start with fresh weights 1614 if use_linear: 1615 m = nn.Linear(*shape, dtype=dtype) 1616 else: 1617 m = nn.Conv2d(2, 3, shape, dtype=dtype) 1618 1619 # We do not support householder for complex inputs 1620 # See Note [Householder complex] 1621 1622 # When using the swap_tensors path, this is needed so that the autograd 1623 # graph is not alive anymore. 1624 if get_swap_module_params_on_conversion(): 1625 w_init = m.weight.clone().detach() 1626 else: 1627 w_init = m.weight.clone() 1628 if parametrization == "householder" and m.weight.is_complex(): 1629 msg = "householder parametrization does not support complex tensors" 1630 with self.assertRaisesRegex(ValueError, msg): 1631 torch.nn.utils.parametrizations.orthogonal( 1632 m, 1633 "weight", 1634 parametrization, 1635 use_trivialization=use_trivialization, 1636 ) 1637 continue 1638 1639 wide_matrix = w_init.size(-2) < w_init.size(-1) 1640 torch.nn.utils.parametrizations.orthogonal( 1641 m, "weight", parametrization, use_trivialization=use_trivialization 1642 ) 1643 # Forwards works as expected 1644 self.assertEqual(w_init.shape, m.weight.shape) 1645 assert_is_orthogonal(m.weight) 1646 if can_initialize: 1647 assert_weight_allclose_Q(m.weight, w_init) 1648 1649 # Intializing with a given orthogonal matrix works 1650 X = torch.randn_like(m.weight) 1651 if wide_matrix: 1652 X = X.mT 1653 w_new = torch.linalg.qr(X).Q 1654 if wide_matrix: 1655 w_new = w_new.mT 1656 if can_initialize: 1657 m.weight = w_new 1658 torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0) 1659 else: 1660 msg = ( 1661 "assign to the matrix exponential or the Cayley parametrization" 1662 ) 1663 with self.assertRaisesRegex(NotImplementedError, msg): 1664 m.weight = w_new 1665 1666 # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix 1667 w_new = torch.randn_like(m.weight) 1668 if can_initialize: 1669 m.weight = w_new 1670 assert_weight_allclose_Q(m.weight, w_new) 1671 else: 1672 msg = ( 1673 "assign to the matrix exponential or the Cayley parametrization" 1674 ) 1675 with self.assertRaisesRegex(NotImplementedError, msg): 1676 m.weight = w_new 1677 1678 opt = torch.optim.SGD(m.parameters(), lr=0.1) 1679 for _ in range(2): 1680 opt.zero_grad() 1681 m(input).norm().backward() 1682 grad = m.parametrizations.weight.original.grad 1683 self.assertIsNotNone(grad) 1684 # We do not update the upper triangular part of the matrix if tall tril if wide 1685 if grad.size(-2) >= grad.size(-1): 1686 zeros_grad = grad.triu(1) 1687 else: 1688 zeros_grad = grad.tril(-1) 1689 self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad)) 1690 # The gradient in the diagonal can only be imaginary because a skew-Hermitian 1691 # matrix has imaginary diagonal 1692 diag_grad = grad.diagonal(dim1=-2, dim2=-1) 1693 if grad.is_complex(): 1694 diag_grad = diag_grad.real 1695 self.assertEqual(diag_grad, torch.zeros_like(diag_grad)) 1696 opt.step() 1697 assert_is_orthogonal(m.weight) 1698 1699 @skipIfNoLapack 1700 @swap([True, False]) 1701 def test_orthogonal_errors(self): 1702 m = nn.Linear(3, 4) 1703 with self.assertRaisesRegex(ValueError, "has to be one of"): 1704 torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo") 1705 1706 with self.assertRaisesRegex(ValueError, "Expected a matrix"): 1707 torch.nn.utils.parametrizations.orthogonal(m, "bias") 1708 1709 torch.nn.utils.parametrizations.orthogonal(m, "weight") 1710 with self.assertRaisesRegex(ValueError, "matrices of shape"): 1711 m.weight = torch.randn(5, 5) 1712 torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1713 1714 @swap([True, False]) 1715 def test_weight_norm_state_dict_compat(self): 1716 m = nn.Linear(4, 5) 1717 m = torch.nn.utils.weight_norm(m) 1718 old_dict = m.state_dict() 1719 1720 m2 = nn.Linear(4, 5) 1721 m2 = torch.nn.utils.parametrizations.weight_norm(m2) 1722 m2.load_state_dict(old_dict) 1723 1724 input = torch.randn(3, 4) 1725 self.assertEqual(m(input), m2(input)) 1726 1727 @swap([True, False]) 1728 def test_weight_norm_pickle(self): 1729 m = nn.Linear(4, 5) 1730 m = torch.nn.utils.parametrizations.weight_norm(m) 1731 with self.assertRaisesRegex(RuntimeError, "state_dict"): 1732 pickle.dumps(m) 1733 1734 @swap([True, False]) 1735 def test_weight_norm_deepcopy(self): 1736 m = nn.Linear(4, 5) 1737 m = torch.nn.utils.parametrizations.weight_norm(m) 1738 m2 = deepcopy(m) 1739 input = torch.randn(3, 4) 1740 self.assertEqual(m(input), m2(input)) 1741 1742 @swap([True]) 1743 def test_wrapper_subclass_parametrization(self): 1744 class Subclassify(nn.Module): 1745 def forward(self, X): 1746 return TwoTensor(X, X) 1747 1748 class UnSubclassify(nn.Module): 1749 def forward(self, X): 1750 return X.a 1751 1752 class IdentityWithRightInverse(nn.Module): 1753 def forward(self, X): 1754 return X 1755 1756 def right_inverse(self, X): 1757 return TwoTensor(X, X) 1758 1759 def _check_parametrization( 1760 parametrization, 1761 type_before_registration, 1762 type_after_registration, 1763 leave_parametrized=False, 1764 type_after_right_inverse=None, 1765 ): 1766 model = nn.Linear(2, 2) 1767 buf = torch.randn(2, 2) 1768 model.buf = torch.nn.Buffer(buf) 1769 if ( 1770 type_before_registration == TwoTensor 1771 and type_after_registration == Tensor 1772 ): 1773 model._apply(lambda t: TwoTensor(t, t)) 1774 initial_weight = model.weight.clone().detach() 1775 initial_weight_id = id(model.weight) 1776 initial_buf = model.buf.clone().detach() 1777 initial_buf_id = id(model.buf) 1778 type_original_weight = ( 1779 type_before_registration 1780 if type_after_right_inverse is None 1781 else type_after_right_inverse 1782 ) 1783 type_original_buf = ( 1784 Tensor if type_original_weight is nn.Parameter else type_original_weight 1785 ) 1786 type_after_removal_buf = ( 1787 type_after_registration if leave_parametrized else type_original_buf 1788 ) 1789 if leave_parametrized: 1790 if type_after_registration is Tensor: 1791 type_after_removal_weight = nn.Parameter 1792 else: 1793 type_after_removal_weight = type_after_registration 1794 else: 1795 type_after_removal_weight = type_original_weight 1796 1797 parametrize.register_parametrization(model, "weight", parametrization()) 1798 parametrize.register_parametrization(model, "buf", parametrization()) 1799 self.assertTrue(hasattr(model, "parametrizations")) 1800 self.assertTrue(parametrize.is_parametrized(model)) 1801 self.assertFalse(parametrize.is_parametrized(model, "bias")) 1802 # checks for weight 1803 self.assertTrue(parametrize.is_parametrized(model, "weight")) 1804 self.assertTrue( 1805 isinstance(model.parametrizations.weight.original, nn.Parameter) 1806 ) 1807 self.assertTrue( 1808 type(model.parametrizations.weight.original) is type_original_weight 1809 ) 1810 self.assertNotIn("weight", model._parameters) 1811 self.assertTrue(type(model.weight) is type_after_registration) 1812 # checks for buf 1813 self.assertTrue(parametrize.is_parametrized(model, "buf")) 1814 self.assertFalse( 1815 isinstance(model.parametrizations.buf.original, nn.Parameter) 1816 ) 1817 self.assertTrue( 1818 type(model.parametrizations.buf.original) is type_original_buf 1819 ) 1820 self.assertTrue(type(model.buf) is type_after_registration) 1821 parametrize.remove_parametrizations( 1822 model, "weight", leave_parametrized=leave_parametrized 1823 ) 1824 parametrize.remove_parametrizations( 1825 model, "buf", leave_parametrized=leave_parametrized 1826 ) 1827 self.assertFalse(hasattr(model, "parametrizations")) 1828 self.assertEqual(model.__class__, nn.Linear) 1829 # checks for weight 1830 self.assertTrue(type(model.weight) is type_after_removal_weight) 1831 self.assertTrue(isinstance(model.weight, nn.Parameter)) 1832 self.assertEqual(id(model.weight), initial_weight_id) 1833 # checks for buf 1834 self.assertTrue(type(model.buf) is type_after_removal_buf) 1835 self.assertFalse(isinstance(model.buf, nn.Parameter)) 1836 self.assertEqual(id(model.buf), initial_buf_id) 1837 if not leave_parametrized and type_after_right_inverse is None: 1838 self.assertEqual(model.weight, initial_weight) 1839 self.assertEqual(model.buf, initial_buf) 1840 1841 _check_parametrization(Subclassify, nn.Parameter, TwoTensor) 1842 _check_parametrization(UnSubclassify, TwoTensor, Tensor) 1843 _check_parametrization( 1844 IdentityWithRightInverse, 1845 nn.Parameter, 1846 TwoTensor, 1847 type_after_right_inverse=TwoTensor, 1848 ) 1849 _check_parametrization( 1850 Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True 1851 ) 1852 _check_parametrization( 1853 UnSubclassify, TwoTensor, Tensor, leave_parametrized=True 1854 ) 1855 _check_parametrization( 1856 IdentityWithRightInverse, 1857 nn.Parameter, 1858 TwoTensor, 1859 leave_parametrized=True, 1860 type_after_right_inverse=TwoTensor, 1861 ) 1862 1863 1864class TestNNParametrizationDevice(NNTestCase): 1865 @swap([True, False]) 1866 def test_weight_norm_parametrization(self, device): 1867 for dtype in [torch.float, torch.bfloat16]: 1868 input = torch.randn(3, 4, dtype=dtype, device=device) 1869 m = nn.Linear(4, 5, dtype=dtype, device=device) 1870 expected_output = m(input) 1871 1872 # add weight normalization 1873 m = torch.nn.utils.parametrizations.weight_norm(m) 1874 self.assertEqual( 1875 m.parametrizations.weight.original1.size(), m.weight.size() 1876 ) 1877 self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1)) 1878 self.assertEqual(m(input), expected_output) 1879 1880 # remove weight norm 1881 torch.nn.utils.parametrize.remove_parametrizations(m, "weight") 1882 self.assertFalse(hasattr(m, "parametrizations")) 1883 self.assertEqual(m(input), expected_output) 1884 1885 # test with dim=1 1886 m = torch.nn.utils.parametrizations.weight_norm(m, dim=1) 1887 self.assertEqual( 1888 m.parametrizations.weight.original1.size(), m.weight.size() 1889 ) 1890 self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4)) 1891 self.assertEqual(m(input), expected_output) 1892 1893 # test with dim=None 1894 m = nn.Linear(4, 5, dtype=dtype, device=device) 1895 expected_output = m(input) 1896 m = torch.nn.utils.parametrizations.weight_norm(m, dim=None) 1897 self.assertEqual(m(input), expected_output) 1898 1899 1900only_for = ("cpu", "cuda") 1901instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for) 1902instantiate_parametrized_tests(TestNNParametrization) 1903 1904if __name__ == "__main__": 1905 run_tests() 1906