1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"] 2*da0073e9SAndroid Build Coastguard Workerimport pickle 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerimport unittest.mock as mock 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 8*da0073e9SAndroid Build Coastguard Workerimport torch.nn.utils.prune as prune 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import NNTestCase 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 11*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 12*da0073e9SAndroid Build Coastguard Worker run_tests, 13*da0073e9SAndroid Build Coastguard Worker TemporaryFileName, 14*da0073e9SAndroid Build Coastguard Worker TEST_NUMPY, 15*da0073e9SAndroid Build Coastguard Worker) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerclass TestPruningNN(NNTestCase): 19*da0073e9SAndroid Build Coastguard Worker _do_cuda_memory_leak_check = True 20*da0073e9SAndroid Build Coastguard Worker _do_cuda_non_default_stream = True 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker # torch/nn/utils/prune.py 23*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy not found") 24*da0073e9SAndroid Build Coastguard Worker def test_validate_pruning_amount_init(self): 25*da0073e9SAndroid Build Coastguard Worker r"""Test the first util function that validates the pruning 26*da0073e9SAndroid Build Coastguard Worker amount requested by the user the moment the pruning method 27*da0073e9SAndroid Build Coastguard Worker is initialized. This test checks that the expected errors are 28*da0073e9SAndroid Build Coastguard Worker raised whenever the amount is invalid. 29*da0073e9SAndroid Build Coastguard Worker The original function runs basic type checking + value range checks. 30*da0073e9SAndroid Build Coastguard Worker It doesn't check the validity of the pruning amount with 31*da0073e9SAndroid Build Coastguard Worker respect to the size of the tensor to prune. That's left to 32*da0073e9SAndroid Build Coastguard Worker `_validate_pruning_amount`, tested below. 33*da0073e9SAndroid Build Coastguard Worker """ 34*da0073e9SAndroid Build Coastguard Worker # neither float not int should raise TypeError 35*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 36*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount="I'm a string") 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker # float not in [0, 1] should raise ValueError 39*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 40*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=1.1) 41*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 42*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=20.0) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker # negative int should raise ValueError 45*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 46*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=-10) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker # all these should pass without errors because they're valid amounts 49*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=0.34) 50*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=1500) 51*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=0) 52*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=0.0) 53*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=1) 54*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount_init(amount=1.0) 55*da0073e9SAndroid Build Coastguard Worker self.assertTrue(True) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy not found") 58*da0073e9SAndroid Build Coastguard Worker def test_validate_pruning_amount(self): 59*da0073e9SAndroid Build Coastguard Worker r"""Tests the second util function that validates the pruning 60*da0073e9SAndroid Build Coastguard Worker amount requested by the user, this time with respect to the size 61*da0073e9SAndroid Build Coastguard Worker of the tensor to prune. The rationale is that if the pruning amount, 62*da0073e9SAndroid Build Coastguard Worker converted to absolute value of units to prune, is larger than 63*da0073e9SAndroid Build Coastguard Worker the number of units in the tensor, then we expect the util function 64*da0073e9SAndroid Build Coastguard Worker to raise a value error. 65*da0073e9SAndroid Build Coastguard Worker """ 66*da0073e9SAndroid Build Coastguard Worker # if amount is int and amount > tensor_size, raise ValueError 67*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 68*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount(amount=20, tensor_size=19) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker # amount is a float so this should not raise an error 71*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount(amount=0.3, tensor_size=0) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker # this is okay 74*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount(amount=19, tensor_size=20) 75*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount(amount=0, tensor_size=0) 76*da0073e9SAndroid Build Coastguard Worker prune._validate_pruning_amount(amount=1, tensor_size=1) 77*da0073e9SAndroid Build Coastguard Worker self.assertTrue(True) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy not found") 80*da0073e9SAndroid Build Coastguard Worker def test_compute_nparams_to_prune(self): 81*da0073e9SAndroid Build Coastguard Worker r"""Test that requested pruning `amount` gets translated into the 82*da0073e9SAndroid Build Coastguard Worker correct absolute number of units to prune. 83*da0073e9SAndroid Build Coastguard Worker """ 84*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0) 85*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10) 86*da0073e9SAndroid Build Coastguard Worker # if 1 is int, means 1 unit 87*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1) 88*da0073e9SAndroid Build Coastguard Worker # if 1. is float, means 100% of units 89*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15) 90*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_sizes(self): 93*da0073e9SAndroid Build Coastguard Worker r"""Test that the new parameters and buffers created by the pruning 94*da0073e9SAndroid Build Coastguard Worker method have the same size as the input tensor to prune. These, in 95*da0073e9SAndroid Build Coastguard Worker fact, correspond to the pruned version of the tensor itself, its 96*da0073e9SAndroid Build Coastguard Worker mask, and its original copy, so the size must match. 97*da0073e9SAndroid Build Coastguard Worker """ 98*da0073e9SAndroid Build Coastguard Worker # fixturize test 99*da0073e9SAndroid Build Coastguard Worker # TODO: add other modules 100*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 101*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker for m in modules: 104*da0073e9SAndroid Build Coastguard Worker for name in names: 105*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 106*da0073e9SAndroid Build Coastguard Worker original_tensor = getattr(m, name) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name=name, amount=0.1) 109*da0073e9SAndroid Build Coastguard Worker # mask has the same size as tensor being pruned 110*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 111*da0073e9SAndroid Build Coastguard Worker original_tensor.size(), getattr(m, name + "_mask").size() 112*da0073e9SAndroid Build Coastguard Worker ) 113*da0073e9SAndroid Build Coastguard Worker # 'orig' tensor has the same size as the original tensor 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 115*da0073e9SAndroid Build Coastguard Worker original_tensor.size(), getattr(m, name + "_orig").size() 116*da0073e9SAndroid Build Coastguard Worker ) 117*da0073e9SAndroid Build Coastguard Worker # new tensor has the same size as the original tensor 118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(original_tensor.size(), getattr(m, name).size()) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_orig(self): 121*da0073e9SAndroid Build Coastguard Worker r"""Test that original tensor is correctly stored in 'orig' 122*da0073e9SAndroid Build Coastguard Worker after pruning is applied. Important to make sure we don't 123*da0073e9SAndroid Build Coastguard Worker lose info about the original unpruned parameter. 124*da0073e9SAndroid Build Coastguard Worker """ 125*da0073e9SAndroid Build Coastguard Worker # fixturize test 126*da0073e9SAndroid Build Coastguard Worker # TODO: add other modules 127*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 128*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker for m in modules: 131*da0073e9SAndroid Build Coastguard Worker for name in names: 132*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 133*da0073e9SAndroid Build Coastguard Worker # tensor prior to pruning 134*da0073e9SAndroid Build Coastguard Worker original_tensor = getattr(m, name) 135*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name=name, amount=0.1) 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(original_tensor, getattr(m, name + "_orig")) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_new_weight(self): 139*da0073e9SAndroid Build Coastguard Worker r"""Test that module.name now contains a pruned version of 140*da0073e9SAndroid Build Coastguard Worker the original tensor obtained from multiplying it by the mask. 141*da0073e9SAndroid Build Coastguard Worker """ 142*da0073e9SAndroid Build Coastguard Worker # fixturize test 143*da0073e9SAndroid Build Coastguard Worker # TODO: add other modules 144*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 145*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker for m in modules: 148*da0073e9SAndroid Build Coastguard Worker for name in names: 149*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 150*da0073e9SAndroid Build Coastguard Worker # tensor prior to pruning 151*da0073e9SAndroid Build Coastguard Worker original_tensor = getattr(m, name) 152*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name=name, amount=0.1) 153*da0073e9SAndroid Build Coastguard Worker # weight = weight_orig * weight_mask 154*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 155*da0073e9SAndroid Build Coastguard Worker getattr(m, name), 156*da0073e9SAndroid Build Coastguard Worker getattr(m, name + "_orig") 157*da0073e9SAndroid Build Coastguard Worker * getattr(m, name + "_mask").to(dtype=original_tensor.dtype), 158*da0073e9SAndroid Build Coastguard Worker ) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker def test_identity_pruning(self): 161*da0073e9SAndroid Build Coastguard Worker r"""Test that a mask of 1s does not change forward or backward.""" 162*da0073e9SAndroid Build Coastguard Worker input_ = torch.ones(1, 5) 163*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2) 164*da0073e9SAndroid Build Coastguard Worker y_prepruning = m(input_) # output prior to pruning 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker # compute grad pre-pruning and check it's equal to all ones 167*da0073e9SAndroid Build Coastguard Worker y_prepruning.sum().backward() 168*da0073e9SAndroid Build Coastguard Worker old_grad_weight = m.weight.grad.clone() # don't grab pointer! 169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) 170*da0073e9SAndroid Build Coastguard Worker old_grad_bias = m.bias.grad.clone() 171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker # remove grads 174*da0073e9SAndroid Build Coastguard Worker m.zero_grad() 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker # force the mask to be made of all 1s 177*da0073e9SAndroid Build Coastguard Worker prune.identity(m, name="weight") 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker # with mask of 1s, output should be identical to no mask 180*da0073e9SAndroid Build Coastguard Worker y_postpruning = m(input_) 181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_prepruning, y_postpruning) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker # with mask of 1s, grad should be identical to no mask 184*da0073e9SAndroid Build Coastguard Worker y_postpruning.sum().backward() 185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_weight, m.weight_orig.grad) 186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_bias, m.bias.grad) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker # calling forward twice in a row shouldn't change output 189*da0073e9SAndroid Build Coastguard Worker y1 = m(input_) 190*da0073e9SAndroid Build Coastguard Worker y2 = m(input_) 191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y2) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_0perc(self): 194*da0073e9SAndroid Build Coastguard Worker r"""Test that a mask of 1s does not change forward or backward.""" 195*da0073e9SAndroid Build Coastguard Worker input_ = torch.ones(1, 5) 196*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2) 197*da0073e9SAndroid Build Coastguard Worker y_prepruning = m(input_) # output prior to pruning 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker # compute grad pre-pruning and check it's equal to all ones 200*da0073e9SAndroid Build Coastguard Worker y_prepruning.sum().backward() 201*da0073e9SAndroid Build Coastguard Worker old_grad_weight = m.weight.grad.clone() # don't grab pointer! 202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) 203*da0073e9SAndroid Build Coastguard Worker old_grad_bias = m.bias.grad.clone() 204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker # remove grads 207*da0073e9SAndroid Build Coastguard Worker m.zero_grad() 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker # force the mask to be made of all 1s 210*da0073e9SAndroid Build Coastguard Worker with mock.patch( 211*da0073e9SAndroid Build Coastguard Worker "torch.nn.utils.prune.RandomUnstructured.compute_mask" 212*da0073e9SAndroid Build Coastguard Worker ) as compute_mask: 213*da0073e9SAndroid Build Coastguard Worker compute_mask.return_value = torch.ones_like(m.weight) 214*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured( 215*da0073e9SAndroid Build Coastguard Worker m, name="weight", amount=0.9 216*da0073e9SAndroid Build Coastguard Worker ) # amount won't count 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker # with mask of 1s, output should be identical to no mask 219*da0073e9SAndroid Build Coastguard Worker y_postpruning = m(input_) 220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_prepruning, y_postpruning) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker # with mask of 1s, grad should be identical to no mask 223*da0073e9SAndroid Build Coastguard Worker y_postpruning.sum().backward() 224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_weight, m.weight_orig.grad) 225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_grad_bias, m.bias.grad) 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker # calling forward twice in a row shouldn't change output 228*da0073e9SAndroid Build Coastguard Worker y1 = m(input_) 229*da0073e9SAndroid Build Coastguard Worker y2 = m(input_) 230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y2) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def test_random_pruning(self): 233*da0073e9SAndroid Build Coastguard Worker input_ = torch.ones(1, 5) 234*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker # define custom mask to assign with mock 237*da0073e9SAndroid Build Coastguard Worker mask = torch.ones_like(m.weight) 238*da0073e9SAndroid Build Coastguard Worker mask[1, 0] = 0 239*da0073e9SAndroid Build Coastguard Worker mask[0, 3] = 0 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker # check grad is zero for masked weights 242*da0073e9SAndroid Build Coastguard Worker with mock.patch( 243*da0073e9SAndroid Build Coastguard Worker "torch.nn.utils.prune.RandomUnstructured.compute_mask" 244*da0073e9SAndroid Build Coastguard Worker ) as compute_mask: 245*da0073e9SAndroid Build Coastguard Worker compute_mask.return_value = mask 246*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name="weight", amount=0.9) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker y_postpruning = m(input_) 249*da0073e9SAndroid Build Coastguard Worker y_postpruning.sum().backward() 250*da0073e9SAndroid Build Coastguard Worker # weight_orig is the parameter, so it's the tensor that will accumulate the grad 251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.weight_orig.grad, mask) # all 1s, except for masked units 252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.bias.grad, torch.ones_like(m.bias)) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker # make sure that weight_orig update doesn't modify [1, 0] and [0, 3] 255*da0073e9SAndroid Build Coastguard Worker old_weight_orig = m.weight_orig.clone() 256*da0073e9SAndroid Build Coastguard Worker # update weights 257*da0073e9SAndroid Build Coastguard Worker learning_rate = 1.0 258*da0073e9SAndroid Build Coastguard Worker for p in m.parameters(): 259*da0073e9SAndroid Build Coastguard Worker p.data.sub_(p.grad.data * learning_rate) 260*da0073e9SAndroid Build Coastguard Worker # since these are pruned, they should not be updated 261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0]) 262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3]) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_forward(self): 265*da0073e9SAndroid Build Coastguard Worker r"""check forward with mask (by hand).""" 266*da0073e9SAndroid Build Coastguard Worker input_ = torch.ones(1, 5) 267*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker # define custom mask to assign with mock 270*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros_like(m.weight) 271*da0073e9SAndroid Build Coastguard Worker mask[1, 0] = 1 272*da0073e9SAndroid Build Coastguard Worker mask[0, 3] = 1 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker with mock.patch( 275*da0073e9SAndroid Build Coastguard Worker "torch.nn.utils.prune.RandomUnstructured.compute_mask" 276*da0073e9SAndroid Build Coastguard Worker ) as compute_mask: 277*da0073e9SAndroid Build Coastguard Worker compute_mask.return_value = mask 278*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name="weight", amount=0.9) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker yhat = m(input_) 281*da0073e9SAndroid Build Coastguard Worker self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0]) 282*da0073e9SAndroid Build Coastguard Worker self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1]) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker def test_remove_pruning_forward(self): 285*da0073e9SAndroid Build Coastguard Worker r"""Remove pruning and check forward is unchanged from previous 286*da0073e9SAndroid Build Coastguard Worker pruned state. 287*da0073e9SAndroid Build Coastguard Worker """ 288*da0073e9SAndroid Build Coastguard Worker input_ = torch.ones(1, 5) 289*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker # define custom mask to assign with mock 292*da0073e9SAndroid Build Coastguard Worker mask = torch.ones_like(m.weight) 293*da0073e9SAndroid Build Coastguard Worker mask[1, 0] = 0 294*da0073e9SAndroid Build Coastguard Worker mask[0, 3] = 0 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker # check grad is zero for masked weights 297*da0073e9SAndroid Build Coastguard Worker with mock.patch( 298*da0073e9SAndroid Build Coastguard Worker "torch.nn.utils.prune.RandomUnstructured.compute_mask" 299*da0073e9SAndroid Build Coastguard Worker ) as compute_mask: 300*da0073e9SAndroid Build Coastguard Worker compute_mask.return_value = mask 301*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name="weight", amount=0.9) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker y_postpruning = m(input_) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker prune.remove(m, "weight") 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker y_postremoval = m(input_) 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_postpruning, y_postremoval) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker def test_pruning_id_consistency(self): 311*da0073e9SAndroid Build Coastguard Worker r"""Test that pruning doesn't change the id of the parameters, which 312*da0073e9SAndroid Build Coastguard Worker would otherwise introduce issues with pre-existing optimizers that 313*da0073e9SAndroid Build Coastguard Worker point to old parameters. 314*da0073e9SAndroid Build Coastguard Worker """ 315*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 2, bias=False) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker tensor_id = id(next(iter(m.parameters()))) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name="weight", amount=0.9) 320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_id, id(next(iter(m.parameters())))) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker prune.remove(m, "weight") 323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_id, id(next(iter(m.parameters())))) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker def test_random_pruning_pickle(self): 326*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 327*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker for m in modules: 330*da0073e9SAndroid Build Coastguard Worker for name in names: 331*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 332*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name=name, amount=0.1) 333*da0073e9SAndroid Build Coastguard Worker m_new = pickle.loads(pickle.dumps(m)) 334*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(m_new, type(m)) 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker def test_multiple_pruning_calls(self): 337*da0073e9SAndroid Build Coastguard Worker # if you call pruning twice, the hook becomes a PruningContainer 338*da0073e9SAndroid Build Coastguard Worker m = nn.Conv3d(2, 2, 2) 339*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(m, name="weight", amount=0.1) 340*da0073e9SAndroid Build Coastguard Worker weight_mask0 = m.weight_mask # save it for later sanity check 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # prune again 343*da0073e9SAndroid Build Coastguard Worker prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0) 344*da0073e9SAndroid Build Coastguard Worker hook = next(iter(m._forward_pre_hooks.values())) 345*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer) 346*da0073e9SAndroid Build Coastguard Worker # check that container._tensor_name is correctly set no matter how 347*da0073e9SAndroid Build Coastguard Worker # many pruning methods are in the container 348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hook._tensor_name, "weight") 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker # check that the pruning container has the right length 351*da0073e9SAndroid Build Coastguard Worker # equal to the number of pruning iters 352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(hook), 2) # m.weight has been pruned twice 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker # check that the entries of the pruning container are of the expected 355*da0073e9SAndroid Build Coastguard Worker # type and in the expected order 356*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured) 357*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker # check that all entries that are 0 in the 1st mask are 0 in the 360*da0073e9SAndroid Build Coastguard Worker # 2nd mask too 361*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0)) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker # prune again 364*da0073e9SAndroid Build Coastguard Worker prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1) 365*da0073e9SAndroid Build Coastguard Worker # check that container._tensor_name is correctly set no matter how 366*da0073e9SAndroid Build Coastguard Worker # many pruning methods are in the container 367*da0073e9SAndroid Build Coastguard Worker hook = next(iter(m._forward_pre_hooks.values())) 368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hook._tensor_name, "weight") 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker def test_pruning_container(self): 371*da0073e9SAndroid Build Coastguard Worker # create an empty container 372*da0073e9SAndroid Build Coastguard Worker container = prune.PruningContainer() 373*da0073e9SAndroid Build Coastguard Worker container._tensor_name = "test" 374*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(container), 0) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=2) 377*da0073e9SAndroid Build Coastguard Worker p._tensor_name = "test" 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker # test adding a pruning method to a container 380*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(p) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker # test error raised if tensor name is different 383*da0073e9SAndroid Build Coastguard Worker q = prune.L1Unstructured(amount=2) 384*da0073e9SAndroid Build Coastguard Worker q._tensor_name = "another_test" 385*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 386*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(q) 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker # test that adding a non-pruning method object to a pruning container 389*da0073e9SAndroid Build Coastguard Worker # raises a TypeError 390*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 391*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(10) 392*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 393*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method("ugh") 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker def test_pruning_container_compute_mask(self): 396*da0073e9SAndroid Build Coastguard Worker r"""Test `compute_mask` of pruning container with a known `t` and 397*da0073e9SAndroid Build Coastguard Worker `default_mask`. Indirectly checks that Ln structured pruning is 398*da0073e9SAndroid Build Coastguard Worker acting on the right axis. 399*da0073e9SAndroid Build Coastguard Worker """ 400*da0073e9SAndroid Build Coastguard Worker # create an empty container 401*da0073e9SAndroid Build Coastguard Worker container = prune.PruningContainer() 402*da0073e9SAndroid Build Coastguard Worker container._tensor_name = "test" 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker # 1) test unstructured pruning 405*da0073e9SAndroid Build Coastguard Worker # create a new pruning method 406*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=2) 407*da0073e9SAndroid Build Coastguard Worker p._tensor_name = "test" 408*da0073e9SAndroid Build Coastguard Worker # add the pruning method to the container 409*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(p) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker # create tensor to be pruned 412*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) 413*da0073e9SAndroid Build Coastguard Worker # create prior mask by hand 414*da0073e9SAndroid Build Coastguard Worker default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) 415*da0073e9SAndroid Build Coastguard Worker # since we are pruning the two lowest magnitude units, the outcome of 416*da0073e9SAndroid Build Coastguard Worker # the calculation should be this: 417*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32) 418*da0073e9SAndroid Build Coastguard Worker computed_mask = container.compute_mask(t, default_mask) 419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask, computed_mask) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker # 2) test structured pruning 422*da0073e9SAndroid Build Coastguard Worker q = prune.LnStructured(amount=1, n=2, dim=0) 423*da0073e9SAndroid Build Coastguard Worker q._tensor_name = "test" 424*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(q) 425*da0073e9SAndroid Build Coastguard Worker # since we are pruning the lowest magnitude one of the two rows, the 426*da0073e9SAndroid Build Coastguard Worker # outcome of the calculation should be this: 427*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32) 428*da0073e9SAndroid Build Coastguard Worker computed_mask = container.compute_mask(t, default_mask) 429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask, computed_mask) 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker # 2) test structured pruning, along another axis 432*da0073e9SAndroid Build Coastguard Worker r = prune.LnStructured(amount=1, n=2, dim=1) 433*da0073e9SAndroid Build Coastguard Worker r._tensor_name = "test" 434*da0073e9SAndroid Build Coastguard Worker container.add_pruning_method(r) 435*da0073e9SAndroid Build Coastguard Worker # since we are pruning the lowest magnitude of the four columns, the 436*da0073e9SAndroid Build Coastguard Worker # outcome of the calculation should be this: 437*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32) 438*da0073e9SAndroid Build Coastguard Worker computed_mask = container.compute_mask(t, default_mask) 439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask, computed_mask) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker def test_l1_unstructured_pruning(self): 442*da0073e9SAndroid Build Coastguard Worker r"""Test that l1 unstructured pruning actually removes the lowest 443*da0073e9SAndroid Build Coastguard Worker entries by l1 norm (by hand). It also checks that applying l1 444*da0073e9SAndroid Build Coastguard Worker unstructured pruning more than once respects the previous mask. 445*da0073e9SAndroid Build Coastguard Worker """ 446*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(4, 2) 447*da0073e9SAndroid Build Coastguard Worker # modify its weight matrix by hand 448*da0073e9SAndroid Build Coastguard Worker m.weight = torch.nn.Parameter( 449*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32) 450*da0073e9SAndroid Build Coastguard Worker ) 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(m, "weight", amount=2) 453*da0073e9SAndroid Build Coastguard Worker expected_weight = torch.tensor( 454*da0073e9SAndroid Build Coastguard Worker [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype 455*da0073e9SAndroid Build Coastguard Worker ) 456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_weight, m.weight) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker # check that pruning again removes the next two smallest entries 459*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(m, "weight", amount=2) 460*da0073e9SAndroid Build Coastguard Worker expected_weight = torch.tensor( 461*da0073e9SAndroid Build Coastguard Worker [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype 462*da0073e9SAndroid Build Coastguard Worker ) 463*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_weight, m.weight) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def test_l1_unstructured_pruning_with_importance_scores(self): 466*da0073e9SAndroid Build Coastguard Worker r"""Test that l1 unstructured pruning actually removes the lowest 467*da0073e9SAndroid Build Coastguard Worker entries of importance scores and not the parameter by l1 norm (by hand). 468*da0073e9SAndroid Build Coastguard Worker It also checks that applying l1 unstructured pruning more than once 469*da0073e9SAndroid Build Coastguard Worker respects the previous mask. 470*da0073e9SAndroid Build Coastguard Worker """ 471*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(4, 2) 472*da0073e9SAndroid Build Coastguard Worker # modify its weight matrix by hand 473*da0073e9SAndroid Build Coastguard Worker m.weight = torch.nn.Parameter( 474*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32) 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker importance_scores = torch.tensor( 477*da0073e9SAndroid Build Coastguard Worker [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 478*da0073e9SAndroid Build Coastguard Worker ) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured( 481*da0073e9SAndroid Build Coastguard Worker m, "weight", amount=2, importance_scores=importance_scores 482*da0073e9SAndroid Build Coastguard Worker ) 483*da0073e9SAndroid Build Coastguard Worker expected_weight = torch.tensor( 484*da0073e9SAndroid Build Coastguard Worker [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype 485*da0073e9SAndroid Build Coastguard Worker ) 486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_weight, m.weight) 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker # check that pruning again removes two entries of m.weight that are colocated with 489*da0073e9SAndroid Build Coastguard Worker # the next two smallest absolute values of importance scores. 490*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured( 491*da0073e9SAndroid Build Coastguard Worker m, "weight", amount=2, importance_scores=importance_scores 492*da0073e9SAndroid Build Coastguard Worker ) 493*da0073e9SAndroid Build Coastguard Worker expected_weight = torch.tensor( 494*da0073e9SAndroid Build Coastguard Worker [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype 495*da0073e9SAndroid Build Coastguard Worker ) 496*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_weight, m.weight) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker def test_unstructured_pruning_same_magnitude(self): 499*da0073e9SAndroid Build Coastguard Worker r"""Since it may happen that the tensor to prune has entries with the 500*da0073e9SAndroid Build Coastguard Worker same exact magnitude, it is important to check that pruning happens 501*da0073e9SAndroid Build Coastguard Worker consistenly based on the bottom % of weights, and not by threshold, 502*da0073e9SAndroid Build Coastguard Worker which would instead kill off *all* units with magnitude = threshold. 503*da0073e9SAndroid Build Coastguard Worker """ 504*da0073e9SAndroid Build Coastguard Worker AMOUNT = 0.2 505*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=AMOUNT) 506*da0073e9SAndroid Build Coastguard Worker # create a random tensors with entries in {-2, 0, 2} 507*da0073e9SAndroid Build Coastguard Worker t = 2 * torch.randint(low=-1, high=2, size=(10, 7)) 508*da0073e9SAndroid Build Coastguard Worker nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement()) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) 511*da0073e9SAndroid Build Coastguard Worker nparams_pruned = torch.sum(computed_mask == 0) 512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nparams_toprune, nparams_pruned) 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker def test_random_structured_pruning_amount(self): 515*da0073e9SAndroid Build Coastguard Worker AMOUNT = 0.6 516*da0073e9SAndroid Build Coastguard Worker AXIS = 2 517*da0073e9SAndroid Build Coastguard Worker p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) 518*da0073e9SAndroid Build Coastguard Worker t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32) 519*da0073e9SAndroid Build Coastguard Worker nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS]) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) 522*da0073e9SAndroid Build Coastguard Worker # check that 1 column is fully prune, the others are left untouched 523*da0073e9SAndroid Build Coastguard Worker remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS] 524*da0073e9SAndroid Build Coastguard Worker per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes)) 525*da0073e9SAndroid Build Coastguard Worker assert per_column_sums == [0, 20] 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker def test_ln_structured_pruning(self): 528*da0073e9SAndroid Build Coastguard Worker r"""Check Ln structured pruning by hand.""" 529*da0073e9SAndroid Build Coastguard Worker m = nn.Conv2d(3, 1, 2) 530*da0073e9SAndroid Build Coastguard Worker m.weight.data = torch.tensor( 531*da0073e9SAndroid Build Coastguard Worker [ 532*da0073e9SAndroid Build Coastguard Worker [ 533*da0073e9SAndroid Build Coastguard Worker [[1.0, 2.0], [1.0, 2.5]], 534*da0073e9SAndroid Build Coastguard Worker [[0.5, 1.0], [0.1, 0.1]], 535*da0073e9SAndroid Build Coastguard Worker [[-3.0, -5.0], [0.1, -1.0]], 536*da0073e9SAndroid Build Coastguard Worker ] 537*da0073e9SAndroid Build Coastguard Worker ] 538*da0073e9SAndroid Build Coastguard Worker ) 539*da0073e9SAndroid Build Coastguard Worker # expected effect of pruning 1 of the 3 channels by L2-norm 540*da0073e9SAndroid Build Coastguard Worker expected_mask_axis1 = torch.ones_like(m.weight) 541*da0073e9SAndroid Build Coastguard Worker expected_mask_axis1[:, 1] = 0.0 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker prune.ln_structured(m, "weight", amount=1, n=2, dim=1) 544*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask_axis1, m.weight_mask) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm 547*da0073e9SAndroid Build Coastguard Worker expected_mask_axis3 = expected_mask_axis1 548*da0073e9SAndroid Build Coastguard Worker expected_mask_axis3[:, :, :, 0] = 0.0 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker prune.ln_structured(m, "weight", amount=1, n=1, dim=-1) 551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask_axis3, m.weight_mask) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker def test_ln_structured_pruning_importance_scores(self): 554*da0073e9SAndroid Build Coastguard Worker r"""Check Ln structured pruning by hand.""" 555*da0073e9SAndroid Build Coastguard Worker m = nn.Conv2d(3, 1, 2) 556*da0073e9SAndroid Build Coastguard Worker m.weight.data = torch.tensor( 557*da0073e9SAndroid Build Coastguard Worker [ 558*da0073e9SAndroid Build Coastguard Worker [ 559*da0073e9SAndroid Build Coastguard Worker [[1.0, 2.0], [1.0, 2.5]], 560*da0073e9SAndroid Build Coastguard Worker [[0.5, 1.0], [0.1, 0.1]], 561*da0073e9SAndroid Build Coastguard Worker [[-3.0, -5.0], [0.1, -1.0]], 562*da0073e9SAndroid Build Coastguard Worker ] 563*da0073e9SAndroid Build Coastguard Worker ] 564*da0073e9SAndroid Build Coastguard Worker ) 565*da0073e9SAndroid Build Coastguard Worker importance_scores = torch.tensor( 566*da0073e9SAndroid Build Coastguard Worker [ 567*da0073e9SAndroid Build Coastguard Worker [ 568*da0073e9SAndroid Build Coastguard Worker [[10.0, 1.0], [10.0, 1.0]], 569*da0073e9SAndroid Build Coastguard Worker [[30.0, 3.0], [30.0, 3.0]], 570*da0073e9SAndroid Build Coastguard Worker [[-20.0, -2.0], [-20.0, -2.0]], 571*da0073e9SAndroid Build Coastguard Worker ] 572*da0073e9SAndroid Build Coastguard Worker ] 573*da0073e9SAndroid Build Coastguard Worker ) 574*da0073e9SAndroid Build Coastguard Worker # expected effect of pruning 1 of the 3 channels by L2-norm 575*da0073e9SAndroid Build Coastguard Worker expected_mask_axis1 = torch.ones_like(m.weight) 576*da0073e9SAndroid Build Coastguard Worker expected_mask_axis1[:, 0] = 0.0 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker prune.ln_structured( 579*da0073e9SAndroid Build Coastguard Worker m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores 580*da0073e9SAndroid Build Coastguard Worker ) 581*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask_axis1, m.weight_mask) 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm 584*da0073e9SAndroid Build Coastguard Worker expected_mask_axis3 = expected_mask_axis1 585*da0073e9SAndroid Build Coastguard Worker expected_mask_axis3[:, :, :, 1] = 0.0 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker prune.ln_structured( 588*da0073e9SAndroid Build Coastguard Worker m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores 589*da0073e9SAndroid Build Coastguard Worker ) 590*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mask_axis3, m.weight_mask) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker def test_remove_pruning(self): 593*da0073e9SAndroid Build Coastguard Worker r"""`prune.remove` removes the hook and the reparametrization 594*da0073e9SAndroid Build Coastguard Worker and makes the pruning final in the original parameter. 595*da0073e9SAndroid Build Coastguard Worker """ 596*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 597*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker for m in modules: 600*da0073e9SAndroid Build Coastguard Worker for name in names: 601*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 602*da0073e9SAndroid Build Coastguard Worker # first prune 603*da0073e9SAndroid Build Coastguard Worker prune.random_unstructured(m, name, amount=0.5) 604*da0073e9SAndroid Build Coastguard Worker self.assertIn(name + "_orig", dict(m.named_parameters())) 605*da0073e9SAndroid Build Coastguard Worker self.assertIn(name + "_mask", dict(m.named_buffers())) 606*da0073e9SAndroid Build Coastguard Worker self.assertNotIn(name, dict(m.named_parameters())) 607*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(m, name)) 608*da0073e9SAndroid Build Coastguard Worker pruned_t = getattr(m, name) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker # then remove pruning 611*da0073e9SAndroid Build Coastguard Worker prune.remove(m, name) 612*da0073e9SAndroid Build Coastguard Worker self.assertIn(name, dict(m.named_parameters())) 613*da0073e9SAndroid Build Coastguard Worker self.assertNotIn(name + "_orig", dict(m.named_parameters())) 614*da0073e9SAndroid Build Coastguard Worker self.assertNotIn(name + "_mask", dict(m.named_buffers())) 615*da0073e9SAndroid Build Coastguard Worker final_t = getattr(m, name) 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pruned_t, final_t) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker def test_remove_pruning_exception(self): 620*da0073e9SAndroid Build Coastguard Worker r"""Removing from an unpruned tensor throws an assertion error""" 621*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 622*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker for m in modules: 625*da0073e9SAndroid Build Coastguard Worker for name in names: 626*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 627*da0073e9SAndroid Build Coastguard Worker # check that the module isn't pruned 628*da0073e9SAndroid Build Coastguard Worker self.assertFalse(prune.is_pruned(m)) 629*da0073e9SAndroid Build Coastguard Worker # since it isn't pruned, pruning can't be removed from it 630*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 631*da0073e9SAndroid Build Coastguard Worker prune.remove(m, name) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker def test_global_pruning(self): 634*da0073e9SAndroid Build Coastguard Worker r"""Test that global l1 unstructured pruning over 2 parameters removes 635*da0073e9SAndroid Build Coastguard Worker the `amount=4` smallest global weights across the 2 parameters. 636*da0073e9SAndroid Build Coastguard Worker """ 637*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(4, 2) 638*da0073e9SAndroid Build Coastguard Worker n = nn.Linear(3, 1) 639*da0073e9SAndroid Build Coastguard Worker # modify the weight matrices by hand 640*da0073e9SAndroid Build Coastguard Worker m.weight = torch.nn.Parameter( 641*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32) 642*da0073e9SAndroid Build Coastguard Worker ) 643*da0073e9SAndroid Build Coastguard Worker n.weight = torch.nn.Parameter( 644*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32) 645*da0073e9SAndroid Build Coastguard Worker ) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker params_to_prune = ( 648*da0073e9SAndroid Build Coastguard Worker (m, "weight"), 649*da0073e9SAndroid Build Coastguard Worker (n, "weight"), 650*da0073e9SAndroid Build Coastguard Worker ) 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker # prune the 4 smallest weights globally by L1 magnitude 653*da0073e9SAndroid Build Coastguard Worker prune.global_unstructured( 654*da0073e9SAndroid Build Coastguard Worker params_to_prune, pruning_method=prune.L1Unstructured, amount=4 655*da0073e9SAndroid Build Coastguard Worker ) 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker expected_mweight = torch.tensor( 658*da0073e9SAndroid Build Coastguard Worker [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype 659*da0073e9SAndroid Build Coastguard Worker ) 660*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_mweight, m.weight) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) 663*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_nweight, n.weight) 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker def test_global_pruning_importance_scores(self): 666*da0073e9SAndroid Build Coastguard Worker r"""Test that global l1 unstructured pruning over 2 parameters removes 667*da0073e9SAndroid Build Coastguard Worker the `amount=4` smallest global weights across the 2 parameters. 668*da0073e9SAndroid Build Coastguard Worker """ 669*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(4, 2) 670*da0073e9SAndroid Build Coastguard Worker n = nn.Linear(3, 1) 671*da0073e9SAndroid Build Coastguard Worker # modify the weight matrices by hand 672*da0073e9SAndroid Build Coastguard Worker m.weight = torch.nn.Parameter( 673*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32) 674*da0073e9SAndroid Build Coastguard Worker ) 675*da0073e9SAndroid Build Coastguard Worker m_importance_scores = torch.tensor( 676*da0073e9SAndroid Build Coastguard Worker [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 677*da0073e9SAndroid Build Coastguard Worker ) 678*da0073e9SAndroid Build Coastguard Worker n.weight = torch.nn.Parameter( 679*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32) 680*da0073e9SAndroid Build Coastguard Worker ) 681*da0073e9SAndroid Build Coastguard Worker n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32) 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker params_to_prune = ( 684*da0073e9SAndroid Build Coastguard Worker (m, "weight"), 685*da0073e9SAndroid Build Coastguard Worker (n, "weight"), 686*da0073e9SAndroid Build Coastguard Worker ) 687*da0073e9SAndroid Build Coastguard Worker importance_scores = { 688*da0073e9SAndroid Build Coastguard Worker (m, "weight"): m_importance_scores, 689*da0073e9SAndroid Build Coastguard Worker (n, "weight"): n_importance_scores, 690*da0073e9SAndroid Build Coastguard Worker } 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker # prune the 4 smallest weights globally by L1 magnitude 693*da0073e9SAndroid Build Coastguard Worker prune.global_unstructured( 694*da0073e9SAndroid Build Coastguard Worker params_to_prune, 695*da0073e9SAndroid Build Coastguard Worker pruning_method=prune.L1Unstructured, 696*da0073e9SAndroid Build Coastguard Worker amount=4, 697*da0073e9SAndroid Build Coastguard Worker importance_scores=importance_scores, 698*da0073e9SAndroid Build Coastguard Worker ) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker expected_m_weight = torch.tensor( 701*da0073e9SAndroid Build Coastguard Worker [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype 702*da0073e9SAndroid Build Coastguard Worker ) 703*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_m_weight, m.weight) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) 706*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_n_weight, n.weight) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker def test_custom_from_mask_pruning(self): 709*da0073e9SAndroid Build Coastguard Worker r"""Test that the CustomFromMask is capable of receiving 710*da0073e9SAndroid Build Coastguard Worker as input at instantiation time a custom mask, and combining it with 711*da0073e9SAndroid Build Coastguard Worker the previous default mask to generate the correct final mask. 712*da0073e9SAndroid Build Coastguard Worker """ 713*da0073e9SAndroid Build Coastguard Worker # new mask 714*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]]) 715*da0073e9SAndroid Build Coastguard Worker # old mask 716*da0073e9SAndroid Build Coastguard Worker default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]) 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Worker # some tensor (not actually used) 719*da0073e9SAndroid Build Coastguard Worker t = torch.rand_like(mask.to(dtype=torch.float32)) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker p = prune.CustomFromMask(mask=mask) 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker computed_mask = p.compute_mask(t, default_mask) 724*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor( 725*da0073e9SAndroid Build Coastguard Worker [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype 726*da0073e9SAndroid Build Coastguard Worker ) 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker self.assertEqual(computed_mask, expected_mask) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker def test_pruning_rollback(self): 731*da0073e9SAndroid Build Coastguard Worker r"""Test that if something fails when the we try to compute the mask, 732*da0073e9SAndroid Build Coastguard Worker then the model isn't left in some intermediate half-pruned state. 733*da0073e9SAndroid Build Coastguard Worker The try/except statement in `apply` should handle rolling back 734*da0073e9SAndroid Build Coastguard Worker to the previous state before pruning began. 735*da0073e9SAndroid Build Coastguard Worker """ 736*da0073e9SAndroid Build Coastguard Worker modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] 737*da0073e9SAndroid Build Coastguard Worker names = ["weight", "bias"] 738*da0073e9SAndroid Build Coastguard Worker 739*da0073e9SAndroid Build Coastguard Worker for m in modules: 740*da0073e9SAndroid Build Coastguard Worker for name in names: 741*da0073e9SAndroid Build Coastguard Worker with self.subTest(m=m, name=name): 742*da0073e9SAndroid Build Coastguard Worker with mock.patch( 743*da0073e9SAndroid Build Coastguard Worker "torch.nn.utils.prune.L1Unstructured.compute_mask" 744*da0073e9SAndroid Build Coastguard Worker ) as compute_mask: 745*da0073e9SAndroid Build Coastguard Worker compute_mask.side_effect = Exception("HA!") 746*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 747*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(m, name=name, amount=0.9) 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker self.assertTrue(name in dict(m.named_parameters())) 750*da0073e9SAndroid Build Coastguard Worker self.assertFalse(name + "_mask" in dict(m.named_buffers())) 751*da0073e9SAndroid Build Coastguard Worker self.assertFalse(name + "_orig" in dict(m.named_parameters())) 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker def test_pruning_serialization_model(self): 754*da0073e9SAndroid Build Coastguard Worker # create a model 755*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 756*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 757*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 758*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 1), 759*da0073e9SAndroid Build Coastguard Worker ) 760*da0073e9SAndroid Build Coastguard Worker # check that everything looks normal before pruning 761*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_orig", model.state_dict()) 762*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_mask", model.state_dict()) 763*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight", model.state_dict()) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker # prune one of its parameters 766*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(module=model[0], name="weight", amount=0.9) 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker # check that the original weight and the new mask are present 769*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_orig", model.state_dict()) 770*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_mask", model.state_dict()) 771*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight", model.state_dict()) 772*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(model[0], "weight")) 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker pruned_weight = model[0].weight 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 777*da0073e9SAndroid Build Coastguard Worker torch.save(model, fname) 778*da0073e9SAndroid Build Coastguard Worker # weights_only=False as this is legacy code that saves the model 779*da0073e9SAndroid Build Coastguard Worker new_model = torch.load(fname, weights_only=False) 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker # check that the original weight and the new mask are present 782*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_orig", new_model.state_dict()) 783*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_mask", new_model.state_dict()) 784*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight", new_model.state_dict()) 785*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(new_model[0], "weight")) 786*da0073e9SAndroid Build Coastguard Worker 787*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pruned_weight, new_model[0].weight) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker def test_pruning_serialization_state_dict(self): 790*da0073e9SAndroid Build Coastguard Worker # create a model 791*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 792*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 793*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 794*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 1), 795*da0073e9SAndroid Build Coastguard Worker ) 796*da0073e9SAndroid Build Coastguard Worker # check that everything looks normal before pruning 797*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_orig", model.state_dict()) 798*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_mask", model.state_dict()) 799*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight", model.state_dict()) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker # prune one of its parameters 802*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(module=model[0], name="weight", amount=0.9) 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker # check that the original weight and the new mask are present 805*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_orig", model.state_dict()) 806*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight_mask", model.state_dict()) 807*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight", model.state_dict()) 808*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(model[0], "weight")) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker pruned_weight = model[0].weight 811*da0073e9SAndroid Build Coastguard Worker 812*da0073e9SAndroid Build Coastguard Worker # make pruning permanent and restore parameter names as in base 813*da0073e9SAndroid Build Coastguard Worker # architecture 814*da0073e9SAndroid Build Coastguard Worker prune.remove(module=model[0], name="weight") 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker # check that the original weight and the new mask are no longer present 817*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_orig", model.state_dict()) 818*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_mask", model.state_dict()) 819*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight", model.state_dict()) 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker # save the state dict of model and reload it into new_model 822*da0073e9SAndroid Build Coastguard Worker new_model = torch.nn.Sequential( 823*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 824*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 825*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 1), 826*da0073e9SAndroid Build Coastguard Worker ) 827*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 828*da0073e9SAndroid Build Coastguard Worker torch.save(model.state_dict(), fname) 829*da0073e9SAndroid Build Coastguard Worker new_model.load_state_dict(torch.load(fname)) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker # check that the original weight and the new mask are not present in 832*da0073e9SAndroid Build Coastguard Worker # new_model either. 833*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_orig", new_model.state_dict()) 834*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("0.weight_mask", new_model.state_dict()) 835*da0073e9SAndroid Build Coastguard Worker self.assertIn("0.weight", new_model.state_dict()) 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pruned_weight, new_model[0].weight) 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker def test_prune(self): 840*da0073e9SAndroid Build Coastguard Worker # create a new pruning method 841*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=2) 842*da0073e9SAndroid Build Coastguard Worker # create tensor to be pruned 843*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) 844*da0073e9SAndroid Build Coastguard Worker # create prior mask by hand 845*da0073e9SAndroid Build Coastguard Worker default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) 846*da0073e9SAndroid Build Coastguard Worker # since we are pruning the two lowest magnitude units, the outcome of 847*da0073e9SAndroid Build Coastguard Worker # the calculation should be this: 848*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) 849*da0073e9SAndroid Build Coastguard Worker pruned_tensor = p.prune(t, default_mask) 850*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t * expected_mask, pruned_tensor) 851*da0073e9SAndroid Build Coastguard Worker 852*da0073e9SAndroid Build Coastguard Worker def test_prune_importance_scores(self): 853*da0073e9SAndroid Build Coastguard Worker # create a new pruning method 854*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=2) 855*da0073e9SAndroid Build Coastguard Worker # create tensor to be pruned 856*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) 857*da0073e9SAndroid Build Coastguard Worker importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to( 858*da0073e9SAndroid Build Coastguard Worker dtype=torch.float32 859*da0073e9SAndroid Build Coastguard Worker ) 860*da0073e9SAndroid Build Coastguard Worker # create prior mask by hand 861*da0073e9SAndroid Build Coastguard Worker default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) 862*da0073e9SAndroid Build Coastguard Worker # since we are pruning the two lowest magnitude units, the outcome of 863*da0073e9SAndroid Build Coastguard Worker # the calculation should be this: 864*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) 865*da0073e9SAndroid Build Coastguard Worker pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) 866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t * expected_mask, pruned_tensor) 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker def test_prune_importance_scores_mimic_default(self): 869*da0073e9SAndroid Build Coastguard Worker # create a new pruning method 870*da0073e9SAndroid Build Coastguard Worker p = prune.L1Unstructured(amount=2) 871*da0073e9SAndroid Build Coastguard Worker # create tensor to be pruned 872*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) 873*da0073e9SAndroid Build Coastguard Worker # create prior mask by hand 874*da0073e9SAndroid Build Coastguard Worker default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) 875*da0073e9SAndroid Build Coastguard Worker # since we are pruning the two lowest magnitude units, the outcome of 876*da0073e9SAndroid Build Coastguard Worker # the calculation should be this: 877*da0073e9SAndroid Build Coastguard Worker expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) 878*da0073e9SAndroid Build Coastguard Worker pruned_tensor_without_importance_scores = p.prune(t, default_mask) 879*da0073e9SAndroid Build Coastguard Worker pruned_tensor_with_importance_scores = p.prune( 880*da0073e9SAndroid Build Coastguard Worker t, default_mask, importance_scores=t 881*da0073e9SAndroid Build Coastguard Worker ) 882*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 883*da0073e9SAndroid Build Coastguard Worker pruned_tensor_without_importance_scores, 884*da0073e9SAndroid Build Coastguard Worker pruned_tensor_with_importance_scores, 885*da0073e9SAndroid Build Coastguard Worker ) 886*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker def test_rnn_pruning(self): 889*da0073e9SAndroid Build Coastguard Worker l = torch.nn.LSTM(32, 32) 890*da0073e9SAndroid Build Coastguard Worker # This Module has 4 parameters called: 891*da0073e9SAndroid Build Coastguard Worker # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker # Pruning one of them causes one of the weights to become a tensor 894*da0073e9SAndroid Build Coastguard Worker prune.l1_unstructured(l, "weight_ih_l0", 0.5) 895*da0073e9SAndroid Build Coastguard Worker assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Worker # Removing the pruning reparametrization restores the Parameter 898*da0073e9SAndroid Build Coastguard Worker prune.remove(l, "weight_ih_l0") 899*da0073e9SAndroid Build Coastguard Worker assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker # Make sure that, upon removal of the reparametrization, the 902*da0073e9SAndroid Build Coastguard Worker # `._parameters` and `.named_parameters` contain the right params. 903*da0073e9SAndroid Build Coastguard Worker # Specifically, the original weight ('weight_ih_l0') should be placed 904*da0073e9SAndroid Build Coastguard Worker # back in the parameters, while the reparametrization component 905*da0073e9SAndroid Build Coastguard Worker # ('weight_ih_l0_orig') should be removed. 906*da0073e9SAndroid Build Coastguard Worker assert "weight_ih_l0" in l._parameters 907*da0073e9SAndroid Build Coastguard Worker assert l._parameters["weight_ih_l0"] is not None 908*da0073e9SAndroid Build Coastguard Worker assert "weight_ih_l0_orig" not in l._parameters 909*da0073e9SAndroid Build Coastguard Worker assert "weight_ih_l0" in dict(l.named_parameters()) 910*da0073e9SAndroid Build Coastguard Worker assert dict(l.named_parameters())["weight_ih_l0"] is not None 911*da0073e9SAndroid Build Coastguard Worker assert "weight_ih_l0_orig" not in dict(l.named_parameters()) 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestPruningNN) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 917*da0073e9SAndroid Build Coastguard Worker run_tests() 918