# Owner(s): ["module: nn"]
import pickle
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TemporaryFileName,
    TEST_NUMPY,
)


class TestPruningNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # torch/nn/utils/prune.py
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount_init(self):
        r"""Test the first util function that validates the pruning
        amount requested by the user the moment the pruning method
        is initialized. This test checks that the expected errors are
        raised whenever the amount is invalid.
        The original function runs basic type checking + value range checks.
        It doesn't check the validity of the pruning amount with
        respect to the size of the tensor to prune. That's left to
        `_validate_pruning_amount`, tested below.
        """
        # an amount that is neither float nor int should raise a TypeError
        with self.assertRaises(TypeError):
            prune._validate_pruning_amount_init(amount="I'm a string")

        # float not in [0, 1] should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=1.1)
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=20.0)

        # negative int should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=-10)

        # all these should pass without errors because they're valid amounts
        prune._validate_pruning_amount_init(amount=0.34)
        prune._validate_pruning_amount_init(amount=1500)
        prune._validate_pruning_amount_init(amount=0)
        prune._validate_pruning_amount_init(amount=0.0)
        prune._validate_pruning_amount_init(amount=1)
        prune._validate_pruning_amount_init(amount=1.0)
        self.assertTrue(True)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount(self):
        r"""Tests the second util function that validates the pruning
        amount requested by the user, this time with respect to the size
        of the tensor to prune. The rationale is that if the pruning amount,
        converted to the absolute number of units to prune, is larger than
        the number of units in the tensor, then we expect the util function
        to raise a ValueError.
        """
        # if amount is int and amount > tensor_size, raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount(amount=20, tensor_size=19)

        # amount is a float so this should not raise an error
        prune._validate_pruning_amount(amount=0.3, tensor_size=0)

        # this is okay
        prune._validate_pruning_amount(amount=19, tensor_size=20)
        prune._validate_pruning_amount(amount=0, tensor_size=0)
        prune._validate_pruning_amount(amount=1, tensor_size=1)
        self.assertTrue(True)
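
    # Note on `amount` semantics used throughout this file (explanatory
    # comment, not part of the original tests): an int `amount` is an
    # absolute count of units to prune, while a float `amount` in [0, 1]
    # is a fraction of the tensor. For instance:
    #   prune._compute_nparams_toprune(amount=5, tensor_size=10)    # -> 5
    #   prune._compute_nparams_toprune(amount=0.5, tensor_size=10)  # -> 5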

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_compute_nparams_to_prune(self):
        r"""Test that requested pruning `amount` gets translated into the
        correct absolute number of units to prune.
        """
        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)
        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)
        # if 1 is int, means 1 unit
        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)
        # if 1. is float, means 100% of units
        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)
        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)

    def test_random_pruning_sizes(self):
        r"""Test that the new parameters and buffers created by the pruning
        method have the same size as the input tensor to prune. These, in
        fact, correspond to the pruned version of the tensor itself, its
        mask, and its original copy, so the size must match.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    original_tensor = getattr(m, name)

                    prune.random_unstructured(m, name=name, amount=0.1)
                    # mask has the same size as tensor being pruned
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_mask").size()
                    )
                    # 'orig' tensor has the same size as the original tensor
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_orig").size()
                    )
                    # new tensor has the same size as the original tensor
                    self.assertEqual(original_tensor.size(), getattr(m, name).size())

    def test_random_pruning_orig(self):
        r"""Test that original tensor is correctly stored in 'orig'
        after pruning is applied. Important to make sure we don't
        lose info about the original unpruned parameter.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    self.assertEqual(original_tensor, getattr(m, name + "_orig"))

    def test_random_pruning_new_weight(self):
        r"""Test that module.name now contains a pruned version of
        the original tensor obtained from multiplying it by the mask.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    # weight = weight_orig * weight_mask
                    self.assertEqual(
                        getattr(m, name),
                        getattr(m, name + "_orig")
                        * getattr(m, name + "_mask").to(dtype=original_tensor.dtype),
                    )
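
    # Background for the reparametrization checked above (explanatory
    # comment, not part of the original tests): pruning registers
    # `<name>_orig` as the trainable Parameter and `<name>_mask` as a
    # buffer, then recomputes
    #   module.<name> = module.<name>_orig * module.<name>_mask
    # in a forward pre-hook, so `<name>` itself becomes a plain attribute.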

    def test_identity_pruning(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        prune.identity(m, name="weight")

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning_0perc(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = torch.ones_like(m.weight)
            prune.random_unstructured(
                m, name="weight", amount=0.9
            )  # amount won't count

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning(self):
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)
        y_postpruning.sum().backward()
        # weight_orig is the parameter, so it's the tensor that will accumulate the grad
        self.assertEqual(m.weight_orig.grad, mask)  # all 1s, except for masked units
        self.assertEqual(m.bias.grad, torch.ones_like(m.bias))

        # make sure that weight_orig update doesn't modify [1, 0] and [0, 3]
        old_weight_orig = m.weight_orig.clone()
        # update weights
        learning_rate = 1.0
        for p in m.parameters():
            p.data.sub_(p.grad.data * learning_rate)
        # since these are pruned, they should not be updated
        self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0])
        self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3])
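
    # Why the masked entries above receive zero gradient (explanatory
    # comment, not part of the original tests): the effective weight is
    # weight_orig * mask, so by the chain rule
    #   d(loss)/d(weight_orig) = d(loss)/d(weight) * mask,
    # which zeroes the gradient at exactly the pruned positions.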

    def test_random_pruning_forward(self):
        r"""Check forward with mask (by hand)."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.zeros_like(m.weight)
        mask[1, 0] = 1
        mask[0, 3] = 1

        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        yhat = m(input_)
        self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0])
        self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1])

    def test_remove_pruning_forward(self):
        r"""Remove pruning and check forward is unchanged from previous
        pruned state.
        """
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # assign the custom mask through the mocked compute_mask
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)

        prune.remove(m, "weight")

        y_postremoval = m(input_)
        self.assertEqual(y_postpruning, y_postremoval)

    def test_pruning_id_consistency(self):
        r"""Test that pruning doesn't change the id of the parameters, which
        would otherwise introduce issues with pre-existing optimizers that
        point to old parameters.
        """
        m = nn.Linear(5, 2, bias=False)

        tensor_id = id(next(iter(m.parameters())))

        prune.random_unstructured(m, name="weight", amount=0.9)
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

        prune.remove(m, "weight")
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

    def test_random_pruning_pickle(self):
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    prune.random_unstructured(m, name=name, amount=0.1)
                    m_new = pickle.loads(pickle.dumps(m))
                    self.assertIsInstance(m_new, type(m))
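
    # Why pickling round-trips (explanatory comment, not part of the original
    # tests): `<name>_orig` and `<name>_mask` are an ordinary Parameter and
    # buffer, and the pruning method instance attached as a forward pre-hook
    # is itself picklable, so the reparametrized module serializes as usual.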

    def test_multiple_pruning_calls(self):
        # if you call pruning twice, the hook becomes a PruningContainer
        m = nn.Conv3d(2, 2, 2)
        prune.l1_unstructured(m, name="weight", amount=0.1)
        weight_mask0 = m.weight_mask  # save it for later sanity check

        # prune again
        prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0)
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        self.assertEqual(hook._tensor_name, "weight")

        # check that the pruning container has the right length
        # equal to the number of pruning iters
        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice

        # check that the entries of the pruning container are of the expected
        # type and in the expected order
        self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured)
        self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured)

        # check that all entries that are 0 in the 1st mask are 0 in the
        # 2nd mask too
        self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0))

        # prune again
        prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertEqual(hook._tensor_name, "weight")

    def test_pruning_container(self):
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"
        self.assertEqual(len(container), 0)

        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"

        # test adding a pruning method to a container
        container.add_pruning_method(p)

        # test error raised if tensor name is different
        q = prune.L1Unstructured(amount=2)
        q._tensor_name = "another_test"
        with self.assertRaises(ValueError):
            container.add_pruning_method(q)

        # test that adding a non-pruning method object to a pruning container
        # raises a TypeError
        with self.assertRaises(TypeError):
            container.add_pruning_method(10)
        with self.assertRaises(TypeError):
            container.add_pruning_method("ugh")

    def test_pruning_container_compute_mask(self):
        r"""Test `compute_mask` of pruning container with a known `t` and
        `default_mask`. Indirectly checks that Ln structured pruning is
        acting on the right axis.
        """
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"

        # 1) test unstructured pruning
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"
        # add the pruning method to the container
        container.add_pruning_method(p)

        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 2) test structured pruning
        q = prune.LnStructured(amount=1, n=2, dim=0)
        q._tensor_name = "test"
        container.add_pruning_method(q)
        # since we are pruning the lowest magnitude one of the two rows, the
        # outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 3) test structured pruning, along another axis
        r = prune.LnStructured(amount=1, n=2, dim=1)
        r._tensor_name = "test"
        container.add_pruning_method(r)
        # since we are pruning the lowest magnitude one of the four columns,
        # the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)
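
    # How PruningContainer combines methods (explanatory comment, not part
    # of the original tests): `compute_mask` applies only the most recently
    # added method, restricted to the units (or, for structured methods, the
    # slices) that `default_mask` has not already fully pruned, and writes
    # the partial result back into `default_mask`, so zeros only accumulate
    # across successive pruning calls.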

    def test_l1_unstructured_pruning(self):
        r"""Test that l1 unstructured pruning actually removes the lowest
        entries by l1 norm (by hand). It also checks that applying l1
        unstructured pruning more than once respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )

        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes the next two smallest entries
        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_l1_unstructured_pruning_with_importance_scores(self):
        r"""Test that l1 unstructured pruning removes the entries with the
        lowest importance scores by l1 norm, rather than the lowest-magnitude
        entries of the parameter itself (by hand). It also checks that
        applying l1 unstructured pruning more than once respects the
        previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )
        importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )

        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes two entries of m.weight that are
        # colocated with the next two smallest absolute values of the
        # importance scores
        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)
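
    # Worked check for the first importance-scored call above (explanatory
    # comment, not part of the original tests): the two smallest
    # |importance_scores| are 1 (at [0, 2]) and |-1| (at [1, 1]), so those
    # positions of m.weight (3 and -3) are the ones zeroed, even though
    # m.weight's own smallest-magnitude entries are 1 and -1.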

    def test_unstructured_pruning_same_magnitude(self):
        r"""Since it may happen that the tensor to prune has entries with the
        same exact magnitude, it is important to check that pruning happens
        consistently based on the bottom % of weights, and not by threshold,
        which would instead kill off *all* units with magnitude = threshold.
        """
        AMOUNT = 0.2
        p = prune.L1Unstructured(amount=AMOUNT)
        # create a random tensor with entries in {-2, 0, 2}
        t = 2 * torch.randint(low=-1, high=2, size=(10, 7))
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        nparams_pruned = torch.sum(computed_mask == 0)
        self.assertEqual(nparams_toprune, nparams_pruned)

    def test_random_structured_pruning_amount(self):
        AMOUNT = 0.6
        AXIS = 2
        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)
        t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32)
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        # check that one column is fully pruned while the others are left untouched
        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]
        per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes))
        assert per_column_sums == [0, 20]

    def test_ln_structured_pruning(self):
        r"""Check Ln structured pruning by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 1] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=2, dim=1)
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 0] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=1, dim=-1)
        self.assertEqual(expected_mask_axis3, m.weight_mask)

    def test_ln_structured_pruning_importance_scores(self):
        r"""Check Ln structured pruning by hand, with importance scores."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        importance_scores = torch.tensor(
            [
                [
                    [[10.0, 1.0], [10.0, 1.0]],
                    [[30.0, 3.0], [30.0, 3.0]],
                    [[-20.0, -2.0], [-20.0, -2.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 0] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 1] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis3, m.weight_mask)
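
    # How LnStructured picks what to prune (explanatory comment, not part of
    # the original tests): it computes the n-norm of every slice along `dim`
    # (collapsing all other dimensions) and zeroes out the `amount` slices
    # with the smallest norms, so whole channels/rows/columns vanish at once.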

    def test_remove_pruning(self):
        r"""`prune.remove` removes the hook and the reparametrization
        and makes the pruning final in the original parameter.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # first prune
                    prune.random_unstructured(m, name, amount=0.5)
                    self.assertIn(name + "_orig", dict(m.named_parameters()))
                    self.assertIn(name + "_mask", dict(m.named_buffers()))
                    self.assertNotIn(name, dict(m.named_parameters()))
                    self.assertTrue(hasattr(m, name))
                    pruned_t = getattr(m, name)

                    # then remove pruning
                    prune.remove(m, name)
                    self.assertIn(name, dict(m.named_parameters()))
                    self.assertNotIn(name + "_orig", dict(m.named_parameters()))
                    self.assertNotIn(name + "_mask", dict(m.named_buffers()))
                    final_t = getattr(m, name)

                    self.assertEqual(pruned_t, final_t)

    def test_remove_pruning_exception(self):
        r"""Removing pruning from an unpruned tensor raises a ValueError."""
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # check that the module isn't pruned
                    self.assertFalse(prune.is_pruned(m))
                    # since it isn't pruned, pruning can't be removed from it
                    with self.assertRaises(ValueError):
                        prune.remove(m, name)

    def test_global_pruning(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` smallest global weights across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )

        # prune the 4 smallest weights globally by L1 magnitude
        prune.global_unstructured(
            params_to_prune, pruning_method=prune.L1Unstructured, amount=4
        )

        expected_mweight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_mweight, m.weight)

        expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_nweight, n.weight)
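
    # Worked check for the global prune above (explanatory comment, not part
    # of the original tests): the pooled magnitudes are {1, 2, 3, 4, 4, 3, 2,
    # 1} from m and {0, 0.1, 2} from n; the 4 smallest are 0 and 0.1 (both in
    # n) plus the two 1s in m, which is exactly what the expected tensors
    # encode.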

    def test_global_pruning_importance_scores(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the weights corresponding to the `amount=4` smallest global importance
        scores across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        m_importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )
        n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32)

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )
        importance_scores = {
            (m, "weight"): m_importance_scores,
            (n, "weight"): n_importance_scores,
        }

        # prune the 4 weights with the smallest global importance scores by
        # L1 magnitude
        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=4,
            importance_scores=importance_scores,
        )

        expected_m_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_m_weight, m.weight)

        expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_n_weight, n.weight)

    def test_custom_from_mask_pruning(self):
        r"""Test that the CustomFromMask is capable of receiving
        as input at instantiation time a custom mask, and combining it with
        the previous default mask to generate the correct final mask.
        """
        # new mask
        mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])
        # old mask
        default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])

        # some tensor (not actually used)
        t = torch.rand_like(mask.to(dtype=torch.float32))

        p = prune.CustomFromMask(mask=mask)

        computed_mask = p.compute_mask(t, default_mask)
        expected_mask = torch.tensor(
            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype
        )

        self.assertEqual(computed_mask, expected_mask)
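
    # Mask combination used above (explanatory comment, not part of the
    # original tests): CustomFromMask multiplies its stored mask with the
    # default mask elementwise, so a unit survives only if both masks keep it.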

    def test_pruning_rollback(self):
        r"""Test that if something fails when we try to compute the mask,
        then the model isn't left in some intermediate half-pruned state.
        The try/except statement in `apply` should handle rolling back
        to the previous state before pruning began.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    with mock.patch(
                        "torch.nn.utils.prune.L1Unstructured.compute_mask"
                    ) as compute_mask:
                        compute_mask.side_effect = Exception("HA!")
                        with self.assertRaises(Exception):
                            prune.l1_unstructured(m, name=name, amount=0.9)

                    self.assertTrue(name in dict(m.named_parameters()))
                    self.assertFalse(name + "_mask" in dict(m.named_buffers()))
                    self.assertFalse(name + "_orig" in dict(m.named_parameters()))

    def test_pruning_serialization_model(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        with TemporaryFileName() as fname:
            torch.save(model, fname)
            # weights_only=False as this is legacy code that saves the model
            new_model = torch.load(fname, weights_only=False)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", new_model.state_dict())
        self.assertIn("0.weight_mask", new_model.state_dict())
        self.assertNotIn("0.weight", new_model.state_dict())
        self.assertTrue(hasattr(new_model[0], "weight"))

        self.assertEqual(pruned_weight, new_model[0].weight)
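
    # Why the state_dict keys change (explanatory comment, not part of the
    # original tests): while pruning is active, `weight` is a derived
    # attribute rather than a Parameter or buffer, so only `weight_orig` and
    # `weight_mask` appear in the state_dict; `prune.remove` restores the
    # plain `weight` entry, as the next test verifies.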

    def test_pruning_serialization_state_dict(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        # make pruning permanent and restore parameter names as in base
        # architecture
        prune.remove(module=model[0], name="weight")

        # check that the original weight and the new mask are no longer present
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # save the state dict of model and reload it into new_model
        new_model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # check that the original weight and the new mask are not present in
        # new_model either
        self.assertNotIn("0.weight_orig", new_model.state_dict())
        self.assertNotIn("0.weight_mask", new_model.state_dict())
        self.assertIn("0.weight", new_model.state_dict())

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_prune(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(
            dtype=torch.float32
        )
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two units with the lowest importance
        # scores, the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores_mimic_default(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor_without_importance_scores = p.prune(t, default_mask)
        pruned_tensor_with_importance_scores = p.prune(
            t, default_mask, importance_scores=t
        )
        self.assertEqual(
            pruned_tensor_without_importance_scores,
            pruned_tensor_with_importance_scores,
        )
        self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores)
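
    # Relationship exercised by the three tests above (explanatory comment,
    # not part of the original tests): `method.prune(t, default_mask)` is
    # equivalent to `t * method.compute_mask(t, default_mask)`, i.e. mask
    # computation followed by an elementwise multiply.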

    def test_rnn_pruning(self):
        l = torch.nn.LSTM(32, 32)
        # This Module has 4 parameters called:
        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'

        # Pruning one of them causes one of the weights to become a tensor
        prune.l1_unstructured(l, "weight_ih_l0", 0.5)
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3

        # Removing the pruning reparametrization restores the Parameter
        prune.remove(l, "weight_ih_l0")
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4

        # Make sure that, upon removal of the reparametrization, the
        # `._parameters` and `.named_parameters` contain the right params.
        # Specifically, the original weight ('weight_ih_l0') should be placed
        # back in the parameters, while the reparametrization component
        # ('weight_ih_l0_orig') should be removed.
        assert "weight_ih_l0" in l._parameters
        assert l._parameters["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in l._parameters
        assert "weight_ih_l0" in dict(l.named_parameters())
        assert dict(l.named_parameters())["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in dict(l.named_parameters())


instantiate_parametrized_tests(TestPruningNN)

if __name__ == "__main__":
    run_tests()