# Owner(s): ["module: unknown"]
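"""Tests for the ``torch.ao.pruning`` sparsifiers.

Covers the ``BaseSparsifier`` lifecycle (prepare, step, state_dict, convert,
squash_mask) as well as the mask behavior of ``WeightNormSparsifier`` and
``NearlyDiagonalSparsifier``.
"""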

import itertools
import logging
import re

import torch
from torch import nn
from torch.ao.pruning import (
    BaseSparsifier,
    FakeSparsity,
    NearlyDiagonalSparsifier,
    WeightNormSparsifier,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import (
    ImplementedSparsifier,
    MockSparseLinear,
    SimpleLinear,
)
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


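# A minimal sketch of the workflow these tests exercise (ImplementedSparsifier
# is a test double from common_pruning whose update_mask zeroes the first row
# of the mask):
#
#     sparsifier = ImplementedSparsifier(test=3)
#     sparsifier.prepare(model, config)  # attach FakeSparsity parametrizations
#     sparsifier.step()                  # update the masks
#     sparsifier.squash_mask()           # fold masks into weights, drop parametrizations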
class TestBaseSparsifier(TestCase):
    def test_constructor(self):
        # Cannot instantiate the abstract base class
        self.assertRaises(TypeError, BaseSparsifier)
        # Can prepare a model with no config: every supported layer is picked up
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, config=None)
        assert len(sparsifier.groups) == 5
        sparsifier.step()
        # Can prepare a model with an explicit config
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        assert len(sparsifier.groups) == 1
        assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight"
        assert "test" in sparsifier.groups[0]
        assert sparsifier.groups[0]["test"] == 3

    def test_prepare_config(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        # Make sure there are no parametrizations before `prepare`
        assert not hasattr(model.seq[0], "parametrizations")
        assert not hasattr(model.linear1, "parametrizations")
        assert not hasattr(model.linear2, "parametrizations")
        sparsifier.prepare(
            model,
            config=[
                {"tensor_fqn": "seq.0.weight", "test": 42},
                # 'linear1' is omitted on purpose; it must be skipped during sparsification
                {"tensor_fqn": "linear2.weight"},
            ],
        )
        assert len(sparsifier.groups) == 2
        # An explicit config value overrides the sparsifier-level default (test=3)
        assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight"
        assert sparsifier.groups[0]["test"] == 42
        # Check that the FQN and the module point to the same location
        assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight"
        assert sparsifier.groups[1]["module"] == model.linear2
        # Check that the parametrizations are attached
        assert hasattr(model.seq[0], "parametrizations")
        assert not hasattr(model.linear1, "parametrizations")
        assert hasattr(model.linear2, "parametrizations")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.enable_mask_update = True
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        sparsifier.step()
        # step() delegates to ImplementedSparsifier.update_mask, which zeroes
        # the first row of the mask
        assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0)

    def test_state_dict(self):
        step_count = 3
        model0 = SimpleLinear()
        sparsifier0 = ImplementedSparsifier(test=3)
        sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
        mask = model0.linear1.parametrizations["weight"][0].mask
        # Give the mask distinctive contents so the round-trip below is observable
        mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
        for _ in range(step_count):
            sparsifier0.step()
        state_dict = sparsifier0.state_dict()

        # Check the expected keys in the state_dict
        assert "state" in state_dict
        assert "step_count" in state_dict["state"]["linear1.weight"]
        assert state_dict["state"]["linear1.weight"]["step_count"] == 3
        assert "groups" in state_dict
        assert "test" in state_dict["groups"][0]
        assert "tensor_fqn" in state_dict["groups"][0]
        assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight"

        # Check that loading the state_dict creates an equivalent sparsifier
        model1 = SimpleLinear()
        sparsifier1 = ImplementedSparsifier()
        sparsifier1.prepare(model1, None)

        assert sparsifier0.state != sparsifier1.state

        # Make sure the masks are different in the beginning
        for mg in sparsifier0.groups:
            if mg["tensor_fqn"] == "linear1.weight":
                mask0 = mg["module"].parametrizations.weight[0].mask
        for mg in sparsifier1.groups:
            if mg["tensor_fqn"] == "linear1.weight":
                mask1 = mg["module"].parametrizations.weight[0].mask
        self.assertNotEqual(mask0, mask1)

        sparsifier1.load_state_dict(state_dict)

        # Make sure the states are loaded and correct
        assert sparsifier0.state == sparsifier1.state

        # Make sure the masks (and all other group entries) match after loading
        assert len(sparsifier0.groups) == len(sparsifier1.groups)
        for idx in range(len(sparsifier0.groups)):
            mg0 = sparsifier0.groups[idx]
            mg1 = sparsifier1.groups[idx]
            for key in mg0.keys():
                assert key in mg1
                if key == "module":
                    # The module objects themselves differ; compare their
                    # parametrizations (and hence masks) instead
                    param0 = mg0[key].parametrizations.weight[0]
                    param1 = mg1[key].parametrizations.weight[0]
                    assert hasattr(param0, "mask")
                    assert hasattr(param1, "mask")
                    self.assertEqual(param0.__dict__, param1.__dict__)
                else:
                    assert mg0[key] == mg1[key]

    def test_convert(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        new_model = sparsifier.convert(
            model, mapping={nn.Linear: MockSparseLinear}, inplace=False
        )

        # Only the configured module is converted; the rest stay nn.Linear
        assert isinstance(new_model.linear1, MockSparseLinear)
        assert isinstance(new_model.seq[0], nn.Linear)
        assert isinstance(new_model.linear2, nn.Linear)

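    # squash_mask applies the FakeSparsity parametrization one last time and
    # removes it, leaving the (masked) weight behind as a plain parameter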
    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        assert hasattr(model.linear1.parametrizations.weight[0], "mask")
        assert is_parametrized(model.linear1, "weight")
        assert not is_parametrized(model.seq[0], "weight")

        sparsifier.squash_mask()
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")

    def test_mask_squash_with_params1(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(
            params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)}
        )
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) is None
        assert model.seq[0].sparse_params.get("bar", None) is None
        assert model.seq[0].sparse_params.get("baz", None) == 1
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None

    def test_mask_squash_with_params2(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(params_to_keep=("foo", "bar"))
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) == 3
        assert model.seq[0].sparse_params.get("bar", None) == 2
        assert model.seq[0].sparse_params.get("baz", None) is None
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None

    def test_mask_squash_with_params3(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(
            params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)}
        )
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) == 3
        assert model.seq[0].sparse_params.get("bar", None) == 2
        assert model.seq[0].sparse_params.get("baz", None) == 1
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None


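# WeightNormSparsifier masks the blocks of a weight tensor with the smallest
# norms until the requested sparsity_level is reached (paraphrased from the
# behavior the tests below verify, not the full contract).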
class TestWeightNormSparsifier(TestCase):
    def test_constructor(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            assert isinstance(g["module"], nn.Linear)
            # The groups are unordered
            assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
        for g in sparsifier.groups:
            # Before the step the mask is all ones, i.e. the sparsity level is 0
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) == 0
        sparsifier.enable_mask_update = True
        sparsifier.step()
        self.assertAlmostEqual(
            model.linear1.parametrizations["weight"][0].mask.mean().item(),
            0.5,
            places=2,
        )
        for g in sparsifier.groups:
            # After the step the sparsity level must have increased
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) > 0
        # Make sure the sparsity level does not collapse to 0 when the weights
        # are repeatedly randomized
        iters_before_collapse = 1000
        for _ in range(iters_before_collapse):
            model.linear1.weight.data = torch.randn(model.linear1.weight.shape)
            sparsifier.step()
        for g in sparsifier.groups:
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) > 0

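    # A worked example of the (1, 4) block pattern checked below: with
    # zeros_per_block=2, every 1x4 block ends up with exactly two zeros,
    # e.g. a block of weights [0.9, 0.1, 0.2, 0.8] would get the mask
    # [1, 0, 0, 1], assuming the smaller-magnitude entries are pruned.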
    def test_step_2_of_4(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier(
            sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
        )
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
        sparsifier.step()
        # Make sure the overall sparsity level is approximately 50%
        mask = model.linear1.parametrizations["weight"][0].mask.to(
            torch.float
        )  # mean works on float only
        self.assertAlmostEqual(mask.mean().item(), 0.5, places=2)
        # Make sure each block has exactly 50% zeros
        module = sparsifier.groups[0]["module"]
        mask = module.parametrizations["weight"][0].mask
        for row in mask:
            for idx in range(0, len(row), 4):
                block = row[idx : idx + 4]
                block, _ = block.sort()
                assert (block[:2] == 0).all()
                assert (block[2:] != 0).all()

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # Check that the mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # Check that the parametrization is attached and is the correct type
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) is FakeSparsity

    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")

    def test_sparsity_levels(self):
        sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
        sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
        zeros_per_blocks = [0, 1, 2, 3, 4]

        # tee() gives two copies of the testcase iterator: one to build the
        # model and config, one to verify the results afterwards
        testcases = itertools.tee(
            itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
        )
        # Create a config and a model covering all the testcases
        model = nn.Sequential()
        sparsifier = WeightNormSparsifier()

        sparsity_per_layer_config = []
        p = re.compile(r"[-\.\s]")
        for sl, sbs, zpb in testcases[0]:
            # Skip the invalid cases where zeros_per_block exceeds the block size
            if zpb > sbs[0] * sbs[1]:
                continue
            layer_name = f"{sl}_{sbs}_{zpb}"
            layer_name = p.sub("_", layer_name)

            layer = nn.Linear(12, 12, bias=False)
            layer.weight = nn.Parameter(torch.ones(12, 12))
            model.add_module(layer_name, layer)
            config = {
                "tensor_fqn": layer_name + ".weight",
                "sparsity_level": sl,
                "sparse_block_shape": sbs,
                "zeros_per_block": zpb,
            }
            sparsity_per_layer_config.append(config)

        sparsifier.prepare(model, sparsity_per_layer_config)
        sparsifier.step()
        sparsifier.squash_mask()
        model.eval()

        for sl, sbs, zpb in testcases[1]:
            if zpb > sbs[0] * sbs[1]:
                continue
            layer_name = f"{sl}_{sbs}_{zpb}"
            layer_name = p.sub("_", layer_name)
            layer = getattr(model, layer_name)

            # Check that the requested level of sparsity is achieved
            sparse_mask = (layer.weight == 0).float()
            if zpb == 0:
                assert sparse_mask.mean() == 0
            else:
                # Expected ratio of zeros in the tensor: the sparsity level
                # clamped to [0, 1], scaled by the zeroed fraction of each block
                true_sl = min(max(sl, 0.0), 1.0)
                true_sl = true_sl * zpb / sbs[0] / sbs[1]
                assert sparse_mask.mean() == true_sl


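# NearlyDiagonalSparsifier keeps only the entries within nearliness // 2 of the
# main diagonal and masks everything else; nearliness=1 therefore keeps just
# the diagonal, which the torch.eye() checks below rely on.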
class TestNearlyDiagonalSparsifier(TestCase):
    def test_constructor(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            assert isinstance(g["module"], nn.Linear)
            # The groups are unordered
            assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])

        for g in sparsifier.groups:
            # Before the step the mask is all ones, i.e. the sparsity level is 0
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) == 0

        sparsifier.enable_mask_update = True
        sparsifier.step()
        # With nearliness=1 only the main diagonal survives
        mask = module.parametrizations["weight"][0].mask
        height, width = mask.shape
        assert torch.all(mask == torch.eye(height, width))

        for g in sparsifier.groups:
            # After the step the sparsity level must have increased
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) > 0

        # Make sure the sparsity level does not collapse to 0 when the weights
        # are repeatedly randomized
        iters_before_collapse = 1000
        for _ in range(iters_before_collapse):
            model.linear1.weight.data = torch.randn(model.linear1.weight.shape)
            sparsifier.step()
        for g in sparsifier.groups:
            module = g["module"]
            assert (1.0 - module.parametrizations["weight"][0].mask.mean()) > 0

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # Check that the mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # Check that the parametrization is attached and is the correct type
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) is FakeSparsity

    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        sparsifier.step()
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")
            weights = module.weight
            height, width = weights.shape
            assert torch.all(
                weights == torch.eye(height, width) * weights
            )  # only the diagonal should remain

    def test_sparsity_levels(self):
        nearliness_levels = list(range(-1, 100))
        model = nn.Sequential()

        p = re.compile(r"[-\.\s]")
        for nearliness in nearliness_levels:
            sparsifier = NearlyDiagonalSparsifier(nearliness=1)
            layer_name = f"{nearliness}"
            layer_name = p.sub("_", layer_name)

            layer = nn.Linear(32, 32, bias=False)
            layer.weight = nn.Parameter(torch.ones(32, 32))
            height, width = layer.weight.shape
            model.add_module(layer_name, layer)
            # The per-layer config overrides the nearliness passed to the constructor
            config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness}

            sparsifier.prepare(model, [config])
            # step() should raise a ValueError for an illegal nearliness:
            # a positive even value, or a band wider than the matrix itself
            if (nearliness > 0 and nearliness % 2 == 0) or (
                nearliness // 2 >= min(width, height)
            ):
                with self.assertRaises(ValueError):
                    sparsifier.step()
            else:
                sparsifier.step()
                sparsifier.squash_mask()
                model.eval()

                layer = getattr(model, layer_name)
                # Verify that the created mask corresponds to the nearliness
                self._verify_nearliness(layer.weight, nearliness)

    # Helper that checks every mask entry against the nearliness band
    def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
        if nearliness <= 0:
            assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1]))
        else:
            height, width = mask.shape
            dist_to_diagonal = nearliness // 2
            for row in range(0, height):
                for col in range(0, width):
                    if abs(row - col) <= dist_to_diagonal:
                        assert mask[row, col] == 1
                    else:
                        assert mask[row, col] == 0

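
# Assumed runner convention (mirroring the sibling files under test/ao/sparsity):
# this module is meant to be collected via test/test_ao_sparsity.py, so guard
# against running it directly. A sketch following that convention:
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_ao_sparsity.py TESTNAME\n\n"
        "instead."
    )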
487