# Owner(s): ["module: nn"]
import pickle
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TemporaryFileName,
    TEST_NUMPY,
)


class TestPruningNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # torch/nn/utils/prune.py
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount_init(self):
        r"""Test the first util function that validates the pruning
        amount requested by the user at the moment the pruning method
        is initialized. This test checks that the expected errors are
        raised whenever the amount is invalid.
        The original function runs basic type checking + value range checks.
        It doesn't check the validity of the pruning amount with
        respect to the size of the tensor to prune. That's left to
        `_validate_pruning_amount`, tested below.
        """
        # a value that is neither float nor int should raise TypeError
        with self.assertRaises(TypeError):
            prune._validate_pruning_amount_init(amount="I'm a string")

        # float not in [0, 1] should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=1.1)
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=20.0)

        # negative int should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=-10)

        # all these should pass without errors because they're valid amounts
        prune._validate_pruning_amount_init(amount=0.34)
        prune._validate_pruning_amount_init(amount=1500)
        prune._validate_pruning_amount_init(amount=0)
        prune._validate_pruning_amount_init(amount=0.0)
        prune._validate_pruning_amount_init(amount=1)
        prune._validate_pruning_amount_init(amount=1.0)
        self.assertTrue(True)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount(self):
        r"""Tests the second util function that validates the pruning
        amount requested by the user, this time with respect to the size
        of the tensor to prune. The rationale is that if the pruning amount,
        converted to absolute value of units to prune, is larger than
        the number of units in the tensor, then we expect the util function
        to raise a value error.
        """
        # if amount is int and amount > tensor_size, raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount(amount=20, tensor_size=19)

        # amount is a float so this should not raise an error
        prune._validate_pruning_amount(amount=0.3, tensor_size=0)

        # this is okay
        prune._validate_pruning_amount(amount=19, tensor_size=20)
        prune._validate_pruning_amount(amount=0, tensor_size=0)
        prune._validate_pruning_amount(amount=1, tensor_size=1)
        self.assertTrue(True)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_compute_nparams_to_prune(self):
        r"""Test that requested pruning `amount` gets translated into the
        correct absolute number of units to prune.
        """
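        # Expected convention (a reading of the helper, not spelled out in
        # this file): an int `amount` is taken as an absolute number of
        # units, while a float `amount` is taken as a fraction and converted
        # via round(amount * tensor_size), e.g. round(0.4 * 17) = round(6.8) = 7.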
        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)
        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)
        # the int 1 means 1 unit
        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)
        # the float 1.0 means 100% of the units
        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)
        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)

    def test_random_pruning_sizes(self):
        r"""Test that the new parameters and buffers created by the pruning
        method have the same size as the input tensor to prune. These, in
        fact, correspond to the pruned version of the tensor itself, its
        mask, and its original copy, so the size must match.
        """
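        # Background for the assertions below (how the reparametrization
        # works): pruning moves the original Parameter to `<name>_orig`,
        # registers a `<name>_mask` buffer, and recomputes `<name>` as
        # `<name>_orig * <name>_mask` through a forward pre-hook.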
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    original_tensor = getattr(m, name)

                    prune.random_unstructured(m, name=name, amount=0.1)
                    # mask has the same size as tensor being pruned
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_mask").size()
                    )
                    # 'orig' tensor has the same size as the original tensor
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_orig").size()
                    )
                    # new tensor has the same size as the original tensor
                    self.assertEqual(original_tensor.size(), getattr(m, name).size())

    def test_random_pruning_orig(self):
        r"""Test that original tensor is correctly stored in 'orig'
        after pruning is applied. Important to make sure we don't
        lose info about the original unpruned parameter.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    self.assertEqual(original_tensor, getattr(m, name + "_orig"))

    def test_random_pruning_new_weight(self):
        r"""Test that module.name now contains a pruned version of
        the original tensor obtained from multiplying it by the mask.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    # weight = weight_orig * weight_mask
                    self.assertEqual(
                        getattr(m, name),
                        getattr(m, name + "_orig")
                        * getattr(m, name + "_mask").to(dtype=original_tensor.dtype),
                    )

    def test_identity_pruning(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        prune.identity(m, name="weight")

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning_0perc(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
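        # (patching compute_mask lets the rest of the pruning machinery run
        # unchanged while we inject a deterministic mask)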
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = torch.ones_like(m.weight)
            prune.random_unstructured(
                m, name="weight", amount=0.9
            )  # amount won't count

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning(self):
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)
        y_postpruning.sum().backward()
        # weight_orig is the parameter, so it's the tensor that will accumulate the grad
        self.assertEqual(m.weight_orig.grad, mask)  # all 1s, except for masked units
        self.assertEqual(m.bias.grad, torch.ones_like(m.bias))

        # make sure that weight_orig update doesn't modify [1, 0] and [0, 3]
        old_weight_orig = m.weight_orig.clone()
        # update weights
        learning_rate = 1.0
        for p in m.parameters():
            p.data.sub_(p.grad.data * learning_rate)
        # since these are pruned, they should not be updated
        self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0])
        self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3])

    def test_random_pruning_forward(self):
        r"""Check forward with mask (by hand)."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.zeros_like(m.weight)
        mask[1, 0] = 1
        mask[0, 3] = 1

        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        yhat = m(input_)
        self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0])
        self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1])

    def test_remove_pruning_forward(self):
        r"""Remove pruning and check forward is unchanged from previous
        pruned state.
        """
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)

        prune.remove(m, "weight")

        y_postremoval = m(input_)
        self.assertEqual(y_postpruning, y_postremoval)

    def test_pruning_id_consistency(self):
        r"""Test that pruning doesn't change the id of the parameters, which
        would otherwise introduce issues with pre-existing optimizers that
        point to old parameters.
        """
        m = nn.Linear(5, 2, bias=False)

        tensor_id = id(next(iter(m.parameters())))

        prune.random_unstructured(m, name="weight", amount=0.9)
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

        prune.remove(m, "weight")
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

    def test_random_pruning_pickle(self):
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    prune.random_unstructured(m, name=name, amount=0.1)
                    m_new = pickle.loads(pickle.dumps(m))
                    self.assertIsInstance(m_new, type(m))

    def test_multiple_pruning_calls(self):
        # if you call pruning twice, the hook becomes a PruningContainer
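        # (the container applies its methods in sequence, so a unit zeroed by
        # an earlier mask is expected to stay zeroed in later masks; that
        # invariant is checked explicitly below)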
        m = nn.Conv3d(2, 2, 2)
        prune.l1_unstructured(m, name="weight", amount=0.1)
        weight_mask0 = m.weight_mask  # save it for later sanity check

        # prune again
        prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0)
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        self.assertEqual(hook._tensor_name, "weight")

        # check that the pruning container has the right length
        # equal to the number of pruning iters
        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice

        # check that the entries of the pruning container are of the expected
        # type and in the expected order
        self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured)
        self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured)

        # check that all entries that are 0 in the 1st mask are 0 in the
        # 2nd mask too
        self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0))

        # prune again
        prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertEqual(hook._tensor_name, "weight")

    def test_pruning_container(self):
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"
        self.assertEqual(len(container), 0)

        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"

        # test adding a pruning method to a container
        container.add_pruning_method(p)

        # test error raised if tensor name is different
        q = prune.L1Unstructured(amount=2)
        q._tensor_name = "another_test"
        with self.assertRaises(ValueError):
            container.add_pruning_method(q)

        # test that adding a non-pruning method object to a pruning container
        # raises a TypeError
        with self.assertRaises(TypeError):
            container.add_pruning_method(10)
        with self.assertRaises(TypeError):
            container.add_pruning_method("ugh")

    def test_pruning_container_compute_mask(self):
        r"""Test `compute_mask` of pruning container with a known `t` and
        `default_mask`. Indirectly checks that Ln structured pruning is
        acting on the right axis.
        """
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"

        # 1) test unstructured pruning
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"
        # add the pruning method to the container
        container.add_pruning_method(p)

        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest-magnitude units still active
        # under the default mask (1 and 2), the outcome of the calculation
        # should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 2) test structured pruning
        q = prune.LnStructured(amount=1, n=2, dim=0)
        q._tensor_name = "test"
        container.add_pruning_method(q)
        # since we are pruning the row with the lowest L2 norm out of the two,
        # the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 3) test structured pruning, along another axis
        r = prune.LnStructured(amount=1, n=2, dim=1)
        r._tensor_name = "test"
        container.add_pruning_method(r)
        # since we are pruning the column with the lowest L2 norm out of the
        # four, the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)
440
441    def test_l1_unstructured_pruning(self):
442        r"""Test that l1 unstructured pruning actually removes the lowest
443        entries by l1 norm (by hand). It also checks that applying l1
444        unstructured pruning more than once respects the previous mask.
445        """
446        m = nn.Linear(4, 2)
447        # modify its weight matrix by hand
448        m.weight = torch.nn.Parameter(
449            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
450        )
451
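        # the two entries with the smallest absolute value are 1 (at [0, 0])
        # and -1 (at [1, 3]), so those are the ones we expect to be zeroed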
        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes the next two smallest entries
        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_l1_unstructured_pruning_with_importance_scores(self):
        r"""Test that l1 unstructured pruning removes the entries with the
        lowest importance scores (by l1 norm), and not the lowest entries
        of the parameter itself (by hand). It also checks that applying l1
        unstructured pruning more than once respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )
        importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )

        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes two entries of m.weight that are colocated with
        # the next two smallest absolute values of importance scores.
        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_unstructured_pruning_same_magnitude(self):
        r"""Since it may happen that the tensor to prune has entries with the
        exact same magnitude, it is important to check that pruning happens
        consistently based on the bottom % of weights, and not by threshold,
        which would instead kill off *all* units with magnitude = threshold.
        """
        AMOUNT = 0.2
        p = prune.L1Unstructured(amount=AMOUNT)
        # create a random tensor with entries in {-2, 0, 2}
        t = 2 * torch.randint(low=-1, high=2, size=(10, 7))
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())
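        # with AMOUNT = 0.2 and 70 entries, exactly round(0.2 * 70) = 14 units
        # must end up pruned, regardless of how many entries tie in magnitude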

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        nparams_pruned = torch.sum(computed_mask == 0)
        self.assertEqual(nparams_toprune, nparams_pruned)

    def test_random_structured_pruning_amount(self):
        AMOUNT = 0.6
        AXIS = 2
        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)
        t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32)
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])
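        # round(0.6 * 2) = 1, so exactly one of the two slices along AXIS
        # should end up pruned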

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        # check that 1 column is fully pruned and the others are left untouched
        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]
        per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes))
        assert per_column_sums == [0, 20]

    def test_ln_structured_pruning(self):
        r"""Check Ln structured pruning by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
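        # per-channel L2 norms along dim=1, computed by hand:
        # channel 0: sqrt(1 + 4 + 1 + 6.25) = 3.50
        # channel 1: sqrt(0.25 + 1 + 0.01 + 0.01) ~= 1.13  <- lowest, pruned
        # channel 2: sqrt(9 + 25 + 0.01 + 1) ~= 5.92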
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 1] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=2, dim=1)
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 0] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=1, dim=-1)
        self.assertEqual(expected_mask_axis3, m.weight_mask)

    def test_ln_structured_pruning_importance_scores(self):
        r"""Check Ln structured pruning with importance scores by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        importance_scores = torch.tensor(
            [
                [
                    [[10.0, 1.0], [10.0, 1.0]],
                    [[30.0, 3.0], [30.0, 3.0]],
                    [[-20.0, -2.0], [-20.0, -2.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 0] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 1] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis3, m.weight_mask)

    def test_remove_pruning(self):
        r"""`prune.remove` removes the hook and the reparametrization
        and makes the pruning final in the original parameter.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # first prune
                    prune.random_unstructured(m, name, amount=0.5)
                    self.assertIn(name + "_orig", dict(m.named_parameters()))
                    self.assertIn(name + "_mask", dict(m.named_buffers()))
                    self.assertNotIn(name, dict(m.named_parameters()))
                    self.assertTrue(hasattr(m, name))
                    pruned_t = getattr(m, name)

                    # then remove pruning
                    prune.remove(m, name)
                    self.assertIn(name, dict(m.named_parameters()))
                    self.assertNotIn(name + "_orig", dict(m.named_parameters()))
                    self.assertNotIn(name + "_mask", dict(m.named_buffers()))
                    final_t = getattr(m, name)

                    self.assertEqual(pruned_t, final_t)

    def test_remove_pruning_exception(self):
        r"""Removing pruning from an unpruned tensor raises a ValueError."""
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # check that the module isn't pruned
                    self.assertFalse(prune.is_pruned(m))
                    # since it isn't pruned, pruning can't be removed from it
                    with self.assertRaises(ValueError):
                        prune.remove(m, name)

    def test_global_pruning(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` smallest global weights across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )
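        # globally, the four smallest magnitudes are 0 and 0.1 (both in
        # n.weight) and the two entries with |value| = 1 in m.weight
        # (at [0, 0] and [1, 3]), so those four should be zeroed out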

        # prune the 4 smallest weights globally by L1 magnitude
        prune.global_unstructured(
            params_to_prune, pruning_method=prune.L1Unstructured, amount=4
        )

        expected_mweight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_mweight, m.weight)

        expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_nweight, n.weight)

    def test_global_pruning_importance_scores(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the weights with the `amount=4` smallest global importance scores
        across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        m_importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )
        n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32)

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )
        importance_scores = {
            (m, "weight"): m_importance_scores,
            (n, "weight"): n_importance_scores,
        }

        # prune the 4 weights with the smallest global importance scores
        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=4,
            importance_scores=importance_scores,
        )

        expected_m_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_m_weight, m.weight)

        expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_n_weight, n.weight)

    def test_custom_from_mask_pruning(self):
        r"""Test that the CustomFromMask is capable of receiving
        as input at instantiation time a custom mask, and combining it with
        the previous default mask to generate the correct final mask.
        """
        # new mask
        mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])
        # old mask
        default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])
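        # the final mask should be the elementwise product of the two masks,
        # so a unit survives only if both masks keep it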

        # some tensor (not actually used)
        t = torch.rand_like(mask.to(dtype=torch.float32))

        p = prune.CustomFromMask(mask=mask)

        computed_mask = p.compute_mask(t, default_mask)
        expected_mask = torch.tensor(
            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype
        )

        self.assertEqual(computed_mask, expected_mask)

    def test_pruning_rollback(self):
        r"""Test that if something fails when we try to compute the mask,
        then the model isn't left in some intermediate half-pruned state.
        The try/except statement in `apply` should handle rolling back
        to the previous state before pruning began.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    with mock.patch(
                        "torch.nn.utils.prune.L1Unstructured.compute_mask"
                    ) as compute_mask:
                        compute_mask.side_effect = Exception("HA!")
                        with self.assertRaises(Exception):
                            prune.l1_unstructured(m, name=name, amount=0.9)

                        self.assertTrue(name in dict(m.named_parameters()))
                        self.assertFalse(name + "_mask" in dict(m.named_buffers()))
                        self.assertFalse(name + "_orig" in dict(m.named_parameters()))

    def test_pruning_serialization_model(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        with TemporaryFileName() as fname:
            torch.save(model, fname)
            # weights_only=False because this legacy path saves the full model object
            new_model = torch.load(fname, weights_only=False)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", new_model.state_dict())
        self.assertIn("0.weight_mask", new_model.state_dict())
        self.assertNotIn("0.weight", new_model.state_dict())
        self.assertTrue(hasattr(new_model[0], "weight"))

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_pruning_serialization_state_dict(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        # make pruning permanent and restore parameter names as in base
        # architecture
        prune.remove(module=model[0], name="weight")

        # check that the original weight and the new mask are no longer present
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # save the state dict of model and reload it into new_model
        new_model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # check that the original weight and the new mask are not present in
        # new_model either.
        self.assertNotIn("0.weight_orig", new_model.state_dict())
        self.assertNotIn("0.weight_mask", new_model.state_dict())
        self.assertIn("0.weight", new_model.state_dict())

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_prune(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(
            dtype=torch.float32
        )
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two units with the lowest importance
        # scores (1.0 and 1.5), the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores_mimic_default(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor_without_importance_scores = p.prune(t, default_mask)
        pruned_tensor_with_importance_scores = p.prune(
            t, default_mask, importance_scores=t
        )
        self.assertEqual(
            pruned_tensor_without_importance_scores,
            pruned_tensor_with_importance_scores,
        )
        self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores)

    def test_rnn_pruning(self):
        l = torch.nn.LSTM(32, 32)
        # This Module has 4 parameters called:
        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'
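        # (_flat_weights is an internal list of the weight tensors consumed
        # by the RNN implementation; after pruning, the reparametrized entry
        # is a recomputed Tensor rather than a Parameter, which is what the
        # counts below check)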

        # Pruning one of them causes one of the weights to become a tensor
        prune.l1_unstructured(l, "weight_ih_l0", 0.5)
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3

        # Removing the pruning reparametrization restores the Parameter
        prune.remove(l, "weight_ih_l0")
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4

        # Make sure that, upon removal of the reparametrization, the
        # `._parameters` and `.named_parameters` contain the right params.
        # Specifically, the original weight ('weight_ih_l0') should be placed
        # back in the parameters, while the reparametrization component
        # ('weight_ih_l0_orig') should be removed.
        assert "weight_ih_l0" in l._parameters
        assert l._parameters["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in l._parameters
        assert "weight_ih_l0" in dict(l.named_parameters())
        assert dict(l.named_parameters())["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in dict(l.named_parameters())


instantiate_parametrized_tests(TestPruningNN)

if __name__ == "__main__":
    run_tests()