# Owner(s): ["module: nn"]
import pickle
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TemporaryFileName,
    TEST_NUMPY,
)


class TestPruningNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # torch/nn/utils/prune.py
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount_init(self):
        r"""Test the first util function that validates the pruning
        amount requested by the user the moment the pruning method
        is initialized. This test checks that the expected errors are
        raised whenever the amount is invalid.
        The original function runs basic type checking + value range checks.
        It doesn't check the validity of the pruning amount with
        respect to the size of the tensor to prune. That's left to
        `_validate_pruning_amount`, tested below.
        """
        # an amount that is neither float nor int should raise TypeError
        with self.assertRaises(TypeError):
            prune._validate_pruning_amount_init(amount="I'm a string")

        # float not in [0, 1] should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=1.1)
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=20.0)

        # negative int should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=-10)

        # all these should pass without errors because they're valid amounts
        prune._validate_pruning_amount_init(amount=0.34)
        prune._validate_pruning_amount_init(amount=1500)
        prune._validate_pruning_amount_init(amount=0)
        prune._validate_pruning_amount_init(amount=0.0)
        prune._validate_pruning_amount_init(amount=1)
        prune._validate_pruning_amount_init(amount=1.0)
        self.assertTrue(True)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount(self):
        r"""Tests the second util function that validates the pruning
        amount requested by the user, this time with respect to the size
        of the tensor to prune. The rationale is that if the pruning amount,
        converted to absolute value of units to prune, is larger than
        the number of units in the tensor, then we expect the util function
        to raise a value error.
        """
        # if amount is int and amount > tensor_size, raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount(amount=20, tensor_size=19)

        # amount is a float so this should not raise an error
        prune._validate_pruning_amount(amount=0.3, tensor_size=0)

        # this is okay
        prune._validate_pruning_amount(amount=19, tensor_size=20)
        prune._validate_pruning_amount(amount=0, tensor_size=0)
        prune._validate_pruning_amount(amount=1, tensor_size=1)
        self.assertTrue(True)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_compute_nparams_to_prune(self):
        r"""Test that requested pruning `amount` gets translated into the
        correct absolute number of units to prune.
        """
        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)
        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)
        # an int amount of 1 means exactly 1 unit
        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)
        # a float amount of 1.0 means 100% of the units
        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)
        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)
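
        # Worked example for the float case (a sketch; assumes the helper
        # rounds amount * tensor_size to the nearest integer):
        # 0.4 * 17 = 6.8, which rounds to the 7 asserted above.
        self.assertEqual(prune._compute_nparams_toprune(amount=0.5, tensor_size=10), 5)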

    def test_random_pruning_sizes(self):
        r"""Test that the new parameters and buffers created by the pruning
        method have the same size as the input tensor to prune. These, in
        fact, correspond to the pruned version of the tensor itself, its
        mask, and its original copy, so the size must match.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    original_tensor = getattr(m, name)

                    prune.random_unstructured(m, name=name, amount=0.1)
                    # the mask has the same size as the tensor being pruned
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_mask").size()
                    )
                    # 'orig' tensor has the same size as the original tensor
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_orig").size()
                    )
                    # new tensor has the same size as the original tensor
                    self.assertEqual(original_tensor.size(), getattr(m, name).size())

    def test_random_pruning_orig(self):
        r"""Test that original tensor is correctly stored in 'orig'
        after pruning is applied. Important to make sure we don't
        lose info about the original unpruned parameter.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    self.assertEqual(original_tensor, getattr(m, name + "_orig"))

    def test_random_pruning_new_weight(self):
        r"""Test that module.name now contains a pruned version of
        the original tensor obtained from multiplying it by the mask.
        """
        # fixturize test
        # TODO: add other modules
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    # weight = weight_orig * weight_mask
                    self.assertEqual(
                        getattr(m, name),
                        getattr(m, name + "_orig")
                        * getattr(m, name + "_mask").to(dtype=original_tensor.dtype),
                    )

    def test_identity_pruning(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        prune.identity(m, name="weight")

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)
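
        # Sanity sketch: prune.identity should have installed an all-ones
        # mask (assumption: the mask buffer is registered as "weight_mask").
        self.assertEqual(m.weight_mask, torch.ones_like(m.weight))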

    def test_random_pruning_0perc(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)  # output prior to pruning

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = torch.ones_like(m.weight)
            prune.random_unstructured(
                m, name="weight", amount=0.9
            )  # amount has no effect here: compute_mask is mocked

        # with mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)
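
        # The mocked compute_mask must have been exercised by the pruning
        # call above (a sketch relying on unittest.mock's bookkeeping).
        self.assertTrue(compute_mask.called)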

    def test_random_pruning(self):
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)
        y_postpruning.sum().backward()
        # weight_orig is the parameter, so it's the tensor that will accumulate the grad
        self.assertEqual(m.weight_orig.grad, mask)  # all 1s, except for masked units
        self.assertEqual(m.bias.grad, torch.ones_like(m.bias))

        # make sure that weight_orig update doesn't modify [1, 0] and [0, 3]
        old_weight_orig = m.weight_orig.clone()
        # update weights
        learning_rate = 1.0
        for p in m.parameters():
            p.data.sub_(p.grad.data * learning_rate)
        # since these are pruned, they should not be updated
        self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0])
        self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3])
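
        # The unmasked entries, by contrast, did move: their grad is 1 and
        # the learning rate is 1.0, so each dropped by exactly 1 (a sketch
        # of the complementary check).
        self.assertEqual(old_weight_orig[0, 0] - 1.0, m.weight_orig[0, 0])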

    def test_random_pruning_forward(self):
        r"""Check the forward pass against a hand-specified mask."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.zeros_like(m.weight)
        mask[1, 0] = 1
        mask[0, 3] = 1

        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        yhat = m(input_)
        self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0])
        self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1])
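
        # With an all-ones input, each output unit reduces to the sum of its
        # row's surviving weights plus the bias, which is exactly what the
        # two assertions above spell out entry by entry.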

    def test_remove_pruning_forward(self):
        r"""Remove pruning and check forward is unchanged from previous
        pruned state.
        """
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)

        prune.remove(m, "weight")

        y_postremoval = m(input_)
        self.assertEqual(y_postpruning, y_postremoval)
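
        # After removal the reparametrization is gone: the mask buffer has
        # been dropped (assumption: the standard cleanup performed by
        # prune.remove, covered in more detail in test_remove_pruning).
        self.assertFalse(hasattr(m, "weight_mask"))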

    def test_pruning_id_consistency(self):
        r"""Test that pruning doesn't change the id of the parameters, which
        would otherwise introduce issues with pre-existing optimizers that
        point to old parameters.
        """
        m = nn.Linear(5, 2, bias=False)

        tensor_id = id(next(iter(m.parameters())))

        prune.random_unstructured(m, name="weight", amount=0.9)
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

        prune.remove(m, "weight")
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

    def test_random_pruning_pickle(self):
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    prune.random_unstructured(m, name=name, amount=0.1)
                    m_new = pickle.loads(pickle.dumps(m))
                    self.assertIsInstance(m_new, type(m))
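                    # The reparametrization should survive the round trip
                    # (a sketch; assumes is_pruned detects the pruning hooks
                    # restored by unpickling).
                    self.assertTrue(prune.is_pruned(m_new))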

    def test_multiple_pruning_calls(self):
        # if you call pruning twice, the hook becomes a PruningContainer
        m = nn.Conv3d(2, 2, 2)
        prune.l1_unstructured(m, name="weight", amount=0.1)
        weight_mask0 = m.weight_mask  # save it for later sanity check

        # prune again
        prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0)
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        self.assertEqual(hook._tensor_name, "weight")

        # check that the pruning container has the right length
        # equal to the number of pruning iters
        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice

        # check that the entries of the pruning container are of the expected
        # type and in the expected order
        self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured)
        self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured)

        # check that all entries that are 0 in the 1st mask are 0 in the
        # 2nd mask too
        self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0))

        # prune again
        prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertEqual(hook._tensor_name, "weight")
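
        # After the third call the same container should now hold all three
        # methods (a sketch mirroring the length check above).
        self.assertEqual(len(hook), 3)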

    def test_pruning_container(self):
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"
        self.assertEqual(len(container), 0)

        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"

        # test adding a pruning method to a container
        container.add_pruning_method(p)

        # test error raised if tensor name is different
        q = prune.L1Unstructured(amount=2)
        q._tensor_name = "another_test"
        with self.assertRaises(ValueError):
            container.add_pruning_method(q)

        # test that adding a non-pruning method object to a pruning container
        # raises a TypeError
        with self.assertRaises(TypeError):
            container.add_pruning_method(10)
        with self.assertRaises(TypeError):
            container.add_pruning_method("ugh")
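
        # Only the compatible method made it in: the two TypeError cases and
        # the mismatched-name case above all raised before insertion.
        self.assertEqual(len(container), 1)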

    def test_pruning_container_compute_mask(self):
        r"""Test `compute_mask` of pruning container with a known `t` and
        `default_mask`. Indirectly checks that Ln structured pruning is
        acting on the right axis.
        """
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"

        # 1) test unstructured pruning
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"
        # add the pruning method to the container
        container.add_pruning_method(p)

        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 2) test structured pruning
        q = prune.LnStructured(amount=1, n=2, dim=0)
        q._tensor_name = "test"
        container.add_pruning_method(q)
        # since we prune whichever of the two rows has the lower L2 norm, the
        # outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 3) test structured pruning, along another axis
        r = prune.LnStructured(amount=1, n=2, dim=1)
        r._tensor_name = "test"
        container.add_pruning_method(r)
        # since we prune the lowest-norm one of the four columns, the
        # outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)
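
        # Note on semantics: each compute_mask call above starts from the
        # same default_mask and applies only the most recently added method
        # on top of it. This is visible from the expected masks themselves:
        # entry [0, 1] is zeroed at stage 2 but back to 1 at stage 3. The
        # cumulative behavior is exercised via module-level pruning in
        # test_multiple_pruning_calls instead.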

    def test_l1_unstructured_pruning(self):
        r"""Test that l1 unstructured pruning actually removes the lowest
        entries by l1 norm (by hand). It also checks that applying l1
        unstructured pruning more than once respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )

        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes the next two smallest entries
        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)
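
        # Four entries are now pruned in total across the two calls, and the
        # mask bookkeeping should agree (assumption: the zeros live in the
        # "weight_mask" buffer).
        self.assertEqual(int(torch.sum(m.weight_mask == 0)), 4)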

    def test_l1_unstructured_pruning_with_importance_scores(self):
        r"""Test that l1 unstructured pruning removes the entries with the
        lowest importance scores, not the lowest-magnitude weights (by hand).
        It also checks that applying l1 unstructured pruning more than once
        respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )
        importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )

        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes two entries of m.weight that are colocated with
        # the next two smallest absolute values of importance scores.
        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_unstructured_pruning_same_magnitude(self):
        r"""Since it may happen that the tensor to prune has entries with the
        same exact magnitude, it is important to check that pruning happens
        consistently based on the bottom % of weights, and not by threshold,
        which would instead kill off *all* units with magnitude = threshold.
        """
        AMOUNT = 0.2
        p = prune.L1Unstructured(amount=AMOUNT)
        # create a random tensor with entries in {-2, 0, 2}
        t = 2 * torch.randint(low=-1, high=2, size=(10, 7))
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        nparams_pruned = torch.sum(computed_mask == 0)
        self.assertEqual(nparams_toprune, nparams_pruned)
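
        # Worked numbers: with AMOUNT=0.2 and a 10 x 7 tensor, exactly
        # round(0.2 * 70) = 14 units must be pruned, however many entries
        # happen to tie at magnitude 0 or 2.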

    def test_random_structured_pruning_amount(self):
        AMOUNT = 0.6
        AXIS = 2
        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)
        t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32)
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        # check that 1 slice is fully pruned while the other is left untouched
        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]
        per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes))
        assert per_column_sums == [0, 20]
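
        # Worked numbers: round(0.6 * 2) = 1 of the 2 slices along dim 2 is
        # pruned, and each slice holds 5 * 4 = 20 entries, hence [0, 20].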

    def test_ln_structured_pruning(self):
        r"""Check Ln structured pruning by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 1] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=2, dim=1)
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 0] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=1, dim=-1)
        self.assertEqual(expected_mask_axis3, m.weight_mask)
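
        # Worked L2 norms along dim=1: channel 0 -> sqrt(12.25) = 3.5,
        # channel 1 -> sqrt(1.27) ~= 1.13, channel 2 -> sqrt(35.01) ~= 5.92,
        # so channel 1 is the one zeroed by the first call above.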

    def test_ln_structured_pruning_importance_scores(self):
        r"""Check Ln structured pruning driven by importance scores, by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        importance_scores = torch.tensor(
            [
                [
                    [[10.0, 1.0], [10.0, 1.0]],
                    [[30.0, 3.0], [30.0, 3.0]],
                    [[-20.0, -2.0], [-20.0, -2.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 0] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 1] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis3, m.weight_mask)
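
        # Here the scores, not the weights, decide: channel 0 has the
        # smallest score norm (sqrt(202) vs sqrt(1818) and sqrt(808)), even
        # though channel 1 had the smallest weight norm in the test above.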

    def test_remove_pruning(self):
        r"""`prune.remove` removes the hook and the reparametrization
        and makes the pruning final in the original parameter.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # first prune
                    prune.random_unstructured(m, name, amount=0.5)
                    self.assertIn(name + "_orig", dict(m.named_parameters()))
                    self.assertIn(name + "_mask", dict(m.named_buffers()))
                    self.assertNotIn(name, dict(m.named_parameters()))
                    self.assertTrue(hasattr(m, name))
                    pruned_t = getattr(m, name)

                    # then remove pruning
                    prune.remove(m, name)
                    self.assertIn(name, dict(m.named_parameters()))
                    self.assertNotIn(name + "_orig", dict(m.named_parameters()))
                    self.assertNotIn(name + "_mask", dict(m.named_buffers()))
                    final_t = getattr(m, name)

                    self.assertEqual(pruned_t, final_t)
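
                    # With the hook and buffers gone, the module reports as
                    # unpruned again (no pruning survives from earlier
                    # iterations either, since each was removed in turn).
                    self.assertFalse(prune.is_pruned(m))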

    def test_remove_pruning_exception(self):
        r"""Removing pruning from an unpruned tensor raises a ValueError."""
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # check that the module isn't pruned
                    self.assertFalse(prune.is_pruned(m))
                    # since it isn't pruned, pruning can't be removed from it
                    with self.assertRaises(ValueError):
                        prune.remove(m, name)

    def test_global_pruning(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` smallest global weights across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )

        # prune the 4 smallest weights globally by L1 magnitude
        prune.global_unstructured(
            params_to_prune, pruning_method=prune.L1Unstructured, amount=4
        )

        expected_mweight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_mweight, m.weight)

        expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_nweight, n.weight)
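
        # The four smallest magnitudes across both tensors are |1| and |-1|
        # in m and |0| and |0.1| in n, matching the zeros asserted above.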

    def test_global_pruning_importance_scores(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` entries with the smallest importance scores globally.
668*da0073e9SAndroid Build Coastguard Worker        """
669*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(4, 2)
670*da0073e9SAndroid Build Coastguard Worker        n = nn.Linear(3, 1)
671*da0073e9SAndroid Build Coastguard Worker        # modify the weight matrices by hand
672*da0073e9SAndroid Build Coastguard Worker        m.weight = torch.nn.Parameter(
673*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
674*da0073e9SAndroid Build Coastguard Worker        )
675*da0073e9SAndroid Build Coastguard Worker        m_importance_scores = torch.tensor(
676*da0073e9SAndroid Build Coastguard Worker            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
677*da0073e9SAndroid Build Coastguard Worker        )
678*da0073e9SAndroid Build Coastguard Worker        n.weight = torch.nn.Parameter(
679*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
680*da0073e9SAndroid Build Coastguard Worker        )
681*da0073e9SAndroid Build Coastguard Worker        n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32)
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        params_to_prune = (
684*da0073e9SAndroid Build Coastguard Worker            (m, "weight"),
685*da0073e9SAndroid Build Coastguard Worker            (n, "weight"),
686*da0073e9SAndroid Build Coastguard Worker        )
687*da0073e9SAndroid Build Coastguard Worker        importance_scores = {
688*da0073e9SAndroid Build Coastguard Worker            (m, "weight"): m_importance_scores,
689*da0073e9SAndroid Build Coastguard Worker            (n, "weight"): n_importance_scores,
690*da0073e9SAndroid Build Coastguard Worker        }
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker        # prune the 4 smallest weights globally by L1 magnitude
693*da0073e9SAndroid Build Coastguard Worker        prune.global_unstructured(
694*da0073e9SAndroid Build Coastguard Worker            params_to_prune,
695*da0073e9SAndroid Build Coastguard Worker            pruning_method=prune.L1Unstructured,
696*da0073e9SAndroid Build Coastguard Worker            amount=4,
697*da0073e9SAndroid Build Coastguard Worker            importance_scores=importance_scores,
698*da0073e9SAndroid Build Coastguard Worker        )
699*da0073e9SAndroid Build Coastguard Worker
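        # the four smallest score magnitudes are 0 and 0.2 in
        # `n_importance_scores` and 1 and -1 in `m_importance_scores`, so the
        # weights at those positions should be zeroed out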
        expected_m_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_m_weight, m.weight)

        expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_n_weight, n.weight)

    def test_custom_from_mask_pruning(self):
        r"""Test that CustomFromMask accepts a custom mask at instantiation
        time and combines it with the previous default mask to generate the
        correct final mask.
        """
        # new mask
        mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])
        # old mask
        default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])

        # arbitrary tensor; CustomFromMask ignores its values when computing
        # the mask
        t = torch.rand_like(mask.to(dtype=torch.float32))

        p = prune.CustomFromMask(mask=mask)

        computed_mask = p.compute_mask(t, default_mask)
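        # the final mask is the elementwise product (logical AND) of the new
        # mask and the default mask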
        expected_mask = torch.tensor(
            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype
        )

        self.assertEqual(computed_mask, expected_mask)

    def test_pruning_rollback(self):
        r"""Test that if something fails while computing the mask, the module
        isn't left in some intermediate half-pruned state.
        The try/except statement in `apply` should handle rolling back
        to the state before pruning began.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    with mock.patch(
                        "torch.nn.utils.prune.L1Unstructured.compute_mask"
                    ) as compute_mask:
                        compute_mask.side_effect = Exception("HA!")
                        with self.assertRaises(Exception):
                            prune.l1_unstructured(m, name=name, amount=0.9)

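                        # after the rollback, the original parameter should be
                        # back in place and no pruning buffers or
                        # reparametrization tensors should remain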
                        self.assertIn(name, dict(m.named_parameters()))
                        self.assertNotIn(name + "_mask", dict(m.named_buffers()))
                        self.assertNotIn(name + "_orig", dict(m.named_parameters()))

    def test_pruning_serialization_model(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))
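        # `weight` is no longer a Parameter; it is recomputed on every forward
        # pass as `weight_orig * weight_mask` by the pruning forward pre-hook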

        pruned_weight = model[0].weight

        with TemporaryFileName() as fname:
            torch.save(model, fname)
            # weights_only=False because this saves/loads the full module
            # object (legacy-style serialization), not just a state dict
            new_model = torch.load(fname, weights_only=False)

        # check that the original weight and the new mask survived the
        # round trip
        self.assertIn("0.weight_orig", new_model.state_dict())
        self.assertIn("0.weight_mask", new_model.state_dict())
        self.assertNotIn("0.weight", new_model.state_dict())
        self.assertTrue(hasattr(new_model[0], "weight"))

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_pruning_serialization_state_dict(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        # make pruning permanent and restore parameter names as in the base
        # architecture
        prune.remove(module=model[0], name="weight")

        # check that the original weight and the new mask are no longer present
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # save the state dict of model and reload it into new_model
        new_model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
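        # `new_model` was never pruned; after `prune.remove` the saved state
        # dict matches the vanilla architecture, so it loads directly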
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # check that the original weight and the new mask are not present in
        # new_model either
        self.assertNotIn("0.weight_orig", new_model.state_dict())
        self.assertNotIn("0.weight_mask", new_model.state_dict())
        self.assertIn("0.weight", new_model.state_dict())

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_prune(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest-magnitude units (the entries
        # with values 1 and 2), the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(
            dtype=torch.float32
        )
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two units with the lowest importance scores
        # (1 and 1.5), the outcome of the calculation should be this:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores_mimic_default(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest magnitude units, the outcome of
        # the calculation should be this:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor_without_importance_scores = p.prune(t, default_mask)
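        # passing the tensor itself as importance scores should reproduce the
        # default magnitude-based pruning exactly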
        pruned_tensor_with_importance_scores = p.prune(
            t, default_mask, importance_scores=t
        )
        self.assertEqual(
            pruned_tensor_without_importance_scores,
            pruned_tensor_with_importance_scores,
        )
        self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores)

    def test_rnn_pruning(self):
        lstm = torch.nn.LSTM(32, 32)
        # This Module has 4 parameters called:
        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'

        # Pruning one of them causes the corresponding flat weight to become a
        # plain tensor instead of a Parameter
        prune.l1_unstructured(lstm, "weight_ih_l0", 0.5)
        self.assertEqual(
            sum(isinstance(p, torch.nn.Parameter) for p in lstm._flat_weights), 3
        )

        # Removing the pruning reparametrization restores the Parameter
        prune.remove(lstm, "weight_ih_l0")
        self.assertEqual(
            sum(isinstance(p, torch.nn.Parameter) for p in lstm._flat_weights), 4
        )

        # Make sure that, upon removal of the reparametrization, the
        # `._parameters` and `.named_parameters` contain the right params.
        # Specifically, the original weight ('weight_ih_l0') should be placed
        # back in the parameters, while the reparametrization component
        # ('weight_ih_l0_orig') should be removed.
        self.assertIn("weight_ih_l0", lstm._parameters)
        self.assertIsNotNone(lstm._parameters["weight_ih_l0"])
        self.assertNotIn("weight_ih_l0_orig", lstm._parameters)
        self.assertIn("weight_ih_l0", dict(lstm.named_parameters()))
        self.assertIsNotNone(dict(lstm.named_parameters())["weight_ih_l0"])
        self.assertNotIn("weight_ih_l0_orig", dict(lstm.named_parameters()))


instantiate_parametrized_tests(TestPruningNN)

if __name__ == "__main__":
    run_tests()