xref: /aosp_15_r20/external/pytorch/test/ao/sparsity/test_data_sparsifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import copy
4import itertools
5import logging
6import math
7from typing import Tuple
8
9import torch
10from torch import nn
11from torch.ao.pruning._experimental.data_sparsifier import (
12    BaseDataSparsifier,
13    DataNormSparsifier,
14)
15from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
16    post_training_sparse_quantize,
17)
18from torch.nn.utils.parametrize import is_parametrized
19from torch.testing._internal.common_utils import TestCase
20
21
22logging.basicConfig(
23    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
24)
25
26
class ImplementedSparsifier(BaseDataSparsifier):
    """Minimal concrete data sparsifier used by the base-class tests.

    Zeroes out the first row of every mask and records, per data entry,
    how many times ``update_mask`` has been invoked.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def update_mask(self, name, data, **kwargs):
        # Sparsify only the first row so the effect is trivial to assert on.
        self.get_mask(name)[0] = 0
        # Keep an invocation counter in the per-entry state dict.
        entry = self.state[name]
        entry["step_count"] = entry.get("step_count", 0) + 1
36
37
class _BaseDataSparsiferTestCase(TestCase):
    r"""Helper test case that validates any supported data type (tensor,
    nn.Parameter, embedding modules, ...) against the ``BaseDataSparsifier``
    API.  The user is required to pass in the data that needs to be sparsified
    and the runner executes the checks that must pass for the data type to be
    supported.
    TODO: Change the structure by creating a separate test case class for each
          member function
    """

    def run_all_checks(self, data_list, data_with_config, defaults):
        """Run the full battery of checks for one data/config combination."""
        self.check_constructor(data_list, data_with_config, defaults)
        self.check_squash_mask(data_list, data_with_config, defaults)
        self.check_add_data(data_list, data_with_config, defaults)
        self.check_step(data_list, data_with_config, defaults)
        self.check_state_dict(data_list, data_with_config, defaults)
        self.check_memory_reference(data_list, data_with_config, defaults)

    @staticmethod
    def _get_name_data_config(some_data, defaults=None):
        """Normalize an entry from either ``data_list`` (a ``(name, data)``
        tuple, whose config falls back to ``defaults``) or ``data_with_config``
        (a dict with explicit ``name``/``data``/``config`` keys).

        Returns:
            A ``(name, data, config)`` triple.
        """
        # Use the builtin ``tuple`` here: ``isinstance`` against the
        # ``typing.Tuple`` alias is deprecated.
        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
            config = defaults
        else:
            # dealing with data_with_config
            name, data, config = (
                some_data["name"],
                some_data["data"],
                some_data["config"],
            )
        return name, data, config

    @staticmethod
    def _make_sparsifier(
        data_list,
        data_with_config,
        defaults,
        sparsifier_type=None,
        sparsifier_kwargs=None,
    ):
        """Build a sparsifier pre-loaded with ``data_list`` and then add each
        entry of ``data_with_config`` with its own per-entry config.

        ``sparsifier_type`` defaults to the minimal ``ImplementedSparsifier``;
        pass a concrete class (e.g. ``DataNormSparsifier``) together with
        ``sparsifier_kwargs`` to exercise a real implementation.
        """
        if sparsifier_type is None:
            sparsifier = ImplementedSparsifier(data_list=data_list, **defaults)
        else:
            kwargs = copy.deepcopy(defaults)
            kwargs.update(sparsifier_kwargs)
            kwargs["data_list"] = data_list
            sparsifier = sparsifier_type(**kwargs)
        assert len(sparsifier.data_groups) == len(data_list)
        for data_config_dict in data_with_config:
            name, data, config = (
                data_config_dict["name"],
                data_config_dict["data"],
                data_config_dict["config"],
            )
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier

    def check_constructor(self, data_list, data_with_config, defaults, **kwargs):
        """All entries (with and without explicit config) must be registered
        in ``data_groups`` with the expected config."""
        sparsifier = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        self.assertEqual(
            len(sparsifier.data_groups),
            len(data_list) + len(data_with_config),
            msg="Sparsifier data groups don't match the input "
            f"({len(sparsifier.data_groups)} vs. "
            f"{len(data_list) + len(data_with_config)}).",
        )

        all_data = data_list + data_with_config

        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults=defaults)
            self.assertIn(name, sparsifier.data_groups)
            self.assertEqual(sparsifier.data_groups[name], config)

    def check_step(self, data_list, data_with_config, defaults, **kwargs):
        """Before ``step()`` the mask must be all ones; after ``step_count``
        steps the first row must be zeroed and ``update_mask`` must have been
        invoked exactly ``step_count`` times (see ImplementedSparsifier)."""
        sparsifier = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        all_data = data_list + data_with_config

        # Check data and mask before doing the step
        for some_data in all_data:
            name, data, _ = self._get_name_data_config(some_data)
            data = sparsifier._extract_weight(data)
            sparsified_data = sparsifier.get_data(name=name, return_original=False)
            original_data = sparsifier.get_data(name=name, return_original=True)
            mask = sparsifier.get_mask(name=name)
            self.assertEqual(sparsified_data, data)
            self.assertEqual(original_data, data)
            self.assertEqualBroadcasting(mask[0], 1)

        step_count = 3

        for _ in range(step_count):
            sparsifier.step()
        for some_data in all_data:
            name, data, _ = self._get_name_data_config(some_data)
            data = sparsifier._extract_weight(data)
            sparsified_data = sparsifier.get_data(name=name, return_original=False)
            original_data = sparsifier.get_data(name=name, return_original=True)
            mask = sparsifier.get_mask(name=name)
            # The sparsified view reflects the mask; the original is untouched.
            self.assertEqualBroadcasting(sparsified_data[0], 0)
            self.assertEqual(original_data, data)
            self.assertEqualBroadcasting(mask[0], 0)
            assert "step_count" in sparsifier.state[name]
            assert sparsifier.state[name]["step_count"] == step_count

    def check_squash_mask(self, data_list, data_with_config, defaults, **kwargs):
        """``squash_mask()`` must remove the parametrization and make the
        original (unmasked) data unavailable."""
        sparsifier = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data)
            assert hasattr(sparsifier._container, name)
            assert is_parametrized(sparsifier._container, name)
        sparsifier.step()
        sparsifier.squash_mask()

        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data)
            assert not is_parametrized(
                sparsifier._container, name
            )  # not parametrized anymore
            with self.assertRaises(ValueError):
                sparsifier.get_data(name, return_original=True)

    def check_add_data(self, data_list, data_with_config, defaults, **kwargs):
        """Re-adding data under an existing name must replace the data while
        preserving both the mask and the existing config."""
        sparsifier = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        all_data = data_list + data_with_config
        for some_data in all_data:
            name1, data1, config = self._get_name_data_config(
                some_data, defaults=defaults
            )
            data1 = sparsifier._extract_weight(data1)
            data1_old = copy.deepcopy(data1)
            assert torch.all(data1 == sparsifier.get_data(name=name1))

            sparsifier.step()
            mask = sparsifier.get_mask(name1)

            data2 = torch.randn(
                data1.shape
            )  # add another data with the same shape as original data
            sparsifier.add_data(name=name1, data=data2)
            assert torch.all(data2 == sparsifier.get_data(name=name1))

            assert torch.all(
                sparsifier.get_mask(name1) == mask
            )  # mask should not change
            assert torch.all(data1_old == data1)

            assert (
                sparsifier.data_groups[name1] == config
            )  # if replaced old_config should match new config

    def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
        """``load_state_dict`` must round-trip state, data_groups, container
        entries and parametrizations between two sparsifiers."""
        sparsifier1 = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        # sparsifier2 starts with a single entry so the two initially differ.
        sparsifier2 = self._make_sparsifier(
            data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs
        )
        sparsifier1.step()

        state_dict1 = sparsifier1.state_dict()

        assert sparsifier1.state != sparsifier2.state
        name, _, _ = self._get_name_data_config(data_list[0])
        self.assertNotEqual(sparsifier1.get_mask(name), sparsifier2.get_mask(name))

        sparsifier2.load_state_dict(state_dict1)
        assert len(sparsifier1.state) == len(sparsifier2.state)
        assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)

        state1 = state_dict1["state"]
        for name in state1:
            # compare mask
            assert name in sparsifier2.state
            assert "mask" in sparsifier2.state[name]
            assert "mask" in sparsifier1.state[name]
            mask1, mask2 = state1[name]["mask"], sparsifier2.state[name]["mask"]
            assert mask1.is_sparse and not mask2.is_sparse
            assert torch.all(
                mask1.to_dense() == mask2
            )  # mask1 is stored as sparse coo now

            # compare data_groups
            dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups
            assert name in dg1 and name in dg2
            assert dg1[name] == dg2[name]

            # compare container
            container1, container2 = sparsifier1._container, sparsifier2._container
            assert torch.all(getattr(container1, name) == getattr(container2, name))
            assert is_parametrized(container1, name) == is_parametrized(
                container2, name
            )
            if is_parametrized(container1, name):
                param1 = getattr(container1.parametrizations, name)[0]
                param2 = getattr(container2.parametrizations, name)[0]
                assert hasattr(param1, "mask")
                assert hasattr(param2, "mask")
                self.assertEqual(param1.__dict__, param2.__dict__)

    def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs):
        """Checks if the data is truly "attached" to the sparsifier. Meaning, when the
        data is changed outside of the sparsifier, the changes must be reflected on the data
        inside the data sparsifier as well.
        This makes sure that the sparsifier is holding the memory reference of the data and
        not copies.

        This test modifies the data and asserts that data in the sparsifier is changed as well
        """
        sparsifier = self._make_sparsifier(
            data_list, data_with_config, defaults=defaults, **kwargs
        )
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, data, _ = self._get_name_data_config(some_data)
            weight = sparsifier._extract_weight(data)
            weight.data = weight + torch.randn(*weight.shape)
            contained_data = sparsifier.get_data(name=name)
            # Same underlying storage => shared memory, not a copy.
            assert (
                weight.data.storage().data_ptr()
                == contained_data.data.storage().data_ptr()
            )
            assert torch.all(contained_data == weight)
270
271
class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
    r"""Helper test case for the norm-based data sparsifiers (L1/L2).
    This inherits _BaseDataSparsiferTestCase wherein some functions are
    over-ridden to accommodate the specific sparsifier.
    TODO: Change the structure by creating a separate test case class for each
          member function
    """

    def run_all_checks(self, data_list, defaults, data_with_config, norm_type="L1"):
        """Run every check for a ``DataNormSparsifier`` with the given norm."""
        assert norm_type in ["L1", "L2"]
        kwargs = {
            "sparsifier_type": DataNormSparsifier,
            "sparsifier_kwargs": {"norm": norm_type},
        }
        self.check_constructor(data_list, data_with_config, defaults, **kwargs)
        self.check_squash_mask(data_list, data_with_config, defaults, **kwargs)
        self.check_add_data(data_list, data_with_config, defaults, **kwargs)
        self.check_state_dict(data_list, data_with_config, defaults, **kwargs)
        self.check_step(data_list, data_with_config, defaults, norm_type=norm_type)
        self.check_step_2_of_4(norm_type=norm_type)
        self.check_sparsity_level(
            data_list, data_with_config, defaults, norm_type=norm_type
        )
        self.check_memory_reference(data_list, data_with_config, defaults, **kwargs)

    @staticmethod
    def _get_bounds_on_actual_sparsity(config, tensor_shape):
        r"""Return ``(lower, upper)`` bounds on the achievable sparsity.
        Note::
            Although we specify the sparsity_level parameter, this does not mean that
            the actual sparsity obtained after sparsification is the same as sparsity_level.
            The actual sparsity depends largely on the shape and the data itself.
        """
        sparsity_level = config["sparsity_level"]
        zeros_per_block = config["zeros_per_block"]
        sparse_block_shape = config["sparse_block_shape"]

        height, width = tensor_shape[-2], tensor_shape[-1]
        block_height, block_width = sparse_block_shape
        number_blocks = math.ceil(height / block_height) * math.ceil(
            width / block_width
        )
        values_per_block = block_height * block_width

        if zeros_per_block == 0:
            return (1.0, 1.0)
        else:
            # min value assumes zeros_per_block is 1
            min_values_sparsified = round(number_blocks * sparsity_level)
            # max value assumes actual zeros_per_block
            max_values_sparsified = min_values_sparsified * min(
                values_per_block, zeros_per_block
            )
            lower_bound = min_values_sparsified / (height * width)
            upper_bound = min(1.0, max_values_sparsified / (height * width))

            lower_bound, upper_bound = round(lower_bound, 3), round(upper_bound, 3)
            return lower_bound, upper_bound

    def check_step(self, data_list, data_with_config, defaults, norm_type="L1"):
        """The mask must be dense before ``step()`` and fall inside the
        computed sparsity bounds afterwards; repeatedly re-adding fresh data
        must keep achieving some sparsity."""
        sparsifier = self._make_sparsifier(
            data_list,
            data_with_config,
            defaults,
            sparsifier_type=DataNormSparsifier,
            sparsifier_kwargs={"norm": norm_type},
        )
        all_data = data_list + data_with_config

        # mask before step() should not be sparsified
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data)
            mask = sparsifier.get_mask(name=name)
            assert (1.0 - mask.mean()) == 0  # checking sparsity level is 0

        sparsifier.step()

        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data)
            mask = sparsifier.get_mask(name=name)
            config = sparsifier.data_groups[name]
            lb, ub = self._get_bounds_on_actual_sparsity(config, mask.shape)
            mask = mask.to(torch.float)
            actual_sparsity = round(1 - mask.mean().item(), 3)
            assert actual_sparsity >= lb and actual_sparsity <= ub
            assert (
                actual_sparsity > 0.0
            )  # exact sparsity level cannot be achieved due to size of tensor

        iters_before_collapse = 100

        test_sparsifier = DataNormSparsifier(
            sparsity_level=0.5,
            sparse_block_shape=(1, 4),
            zeros_per_block=4,
            norm=norm_type,
        )

        for _ in range(iters_before_collapse):
            new_data = torch.randn(20, 20)
            test_sparsifier.add_data(name="test_data", data=new_data)
            test_sparsifier.step()
            mask = test_sparsifier.get_mask(name="test_data")
            mask = mask.to(torch.float)
            assert (1.0 - mask.mean().item()) > 0  # some sparsity achieved

    def check_step_2_of_4(self, norm_type):
        """With zeros_per_block=2 and 1x4 blocks at sparsity 1.0, exactly two
        entries of every block must be zeroed (2:4 structured sparsity)."""
        # overriding default config for test purposes
        default_config = {
            "sparsity_level": 1.0,
            "zeros_per_block": 2,
            "sparse_block_shape": (1, 4),
        }
        data_list = [("test_data", torch.randn(4, 4))]

        sparsifier = DataNormSparsifier(
            data_list=data_list, norm=norm_type, **default_config
        )
        sparsifier.step()

        for some_data in data_list:
            name, _ = some_data
            mask = sparsifier.get_mask(name=name)
            mask = mask.to(torch.float)
            self.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2)
            for row in mask:
                for idx in range(0, len(row), 4):
                    block = row[idx : idx + 4]
                    block, _ = block.sort()
                    assert (block[:2] == 0).all()
                    assert (block[2:] != 0).all()

    def check_sparsity_level(
        self, data_list, data_with_config, defaults, norm_type="L1"
    ):
        """Sweep (sparsity_level, block shape, zeros_per_block) combinations
        and verify that, after stepping and squashing, the fraction of zeroed
        elements matches the expected level exactly."""
        sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
        sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
        zeros_per_blocks = [0, 1, 2, 3, 4]
        sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type)

        # Materialize the product so it can be iterated twice.  The previous
        # version used itertools.tee but consumed the same tee-ed iterator in
        # both loops, so the verification loop below never actually ran.
        testcases = list(
            itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
        )

        assert (
            len(data_with_config) > 0
            and "name" in data_with_config[0]
            and "data" in data_with_config[0]
        )
        # get some data
        name, data = data_with_config[0]["name"], data_with_config[0]["data"]
        for idx, (sl, sbs, zpb) in enumerate(testcases):
            new_name = f"{name}_{idx}"
            if zpb > sbs[0] * sbs[1]:
                continue  # more zeros requested than fit in a block
            current_config = {
                "sparsity_level": sl,
                "sparse_block_shape": sbs,
                "zeros_per_block": zpb,
            }
            sparsifier.add_data(name=new_name, data=data, **current_config)

        sparsifier.step()
        sparsifier.squash_mask()
        for idx, (sl, sbs, zpb) in enumerate(testcases):
            if zpb > sbs[0] * sbs[1]:
                continue  # skipped above; never added to the sparsifier
            new_name = f"{name}_{idx}"
            sparsified_data = sparsifier.get_data(name=new_name, return_original=False)
            # sparse mask
            sparse_mask = (sparsified_data == 0).float()
            if zpb == 0:
                assert sparse_mask.mean() == 0
            else:
                # Ratio of individual zeros in the tensor; the sparsity level
                # is clamped into [0, 1] and scaled by zeros-per-block density.
                true_sl = min(max(sl, 0.0), 1.0)
                true_sl = true_sl * zpb / sbs[0] / sbs[1]
                assert sparse_mask.mean() == true_sl
450
451
class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
    """To add unit tests to support new data types for the BaseDataSparsifier, create the following
        data_list: List of tuples of name, data to be added to the constructor
        defaults: default config for the above data in data_list
        data_with_config: list of dictionaries defining name, data and config (look test_tensors())

    Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests()
    """

    def test_tensors(self):
        # Plain tensors of various sizes, three via the constructor list and
        # two more via per-entry configs.
        shapes = [(3, 3), (4, 4), (5, 5)]
        data_list = [
            (f"tensor{i}", torch.randn(*shape))
            for i, shape in enumerate(shapes, start=1)
        ]
        defaults = {"test": 3}

        data_with_config = [
            {"name": "tensor4", "data": torch.randn(1, 1), "config": {"test": 7}},
            {"name": "tensor5", "data": torch.randn(4, 4), "config": {"test": 8}},
        ]
        self.run_all_checks(
            data_list=data_list, defaults=defaults, data_with_config=data_with_config
        )

    def test_nn_parameters(self):
        # Same layout as test_tensors, but wrapped in nn.Parameter.
        shapes = [(3, 3), (4, 4), (5, 5)]
        data_list = [
            (f"param{i}", nn.Parameter(torch.randn(*shape)))
            for i, shape in enumerate(shapes, start=1)
        ]
        defaults = {"test": 3}

        data_with_config = [
            {
                "name": "param4",
                "data": nn.Parameter(torch.randn(1, 1)),
                "config": {"test": 7},
            },
            {
                "name": "param5",
                "data": nn.Parameter(torch.randn(4, 4)),
                "config": {"test": 8},
            },
        ]
        self.run_all_checks(
            data_list=data_list, defaults=defaults, data_with_config=data_with_config
        )

    def test_nn_embeddings(self):
        # Embedding and EmbeddingBag modules, mixed in one data_list.
        data_list = [
            ("emb1", nn.Embedding(10, 3)),
            ("emb1_bag", nn.EmbeddingBag(10, 3)),
            ("emb2", nn.Embedding(20, 3)),
            ("emb2_bag", nn.EmbeddingBag(20, 3)),
        ]
        defaults = {"test": 3}

        data_with_config = [
            {"name": "emb3", "data": nn.Embedding(15, 3), "config": {"test": 7}},
            {"name": "emb3_bag", "data": nn.EmbeddingBag(20, 3), "config": {"test": 8}},
        ]
        self.run_all_checks(
            data_list=data_list, defaults=defaults, data_with_config=data_with_config
        )
524
525
class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
    """To add unit tests to support new data types for the NormDataSparsifier, create the following
    data_list: List of tuples of name, data to be added to the constructor
    defaults: default config for the above data in data_list
    data_with_config: list of dictionaries defining name, data and config (look test_tensors())

    Once the above is done, create an instance of _NormDataSparsifierTestCase and call run_all_checks()
    """

    def test_tensors(self):
        data_list = [
            ("tensor1", torch.randn(1, 10)),
            ("tensor2", torch.randn(4, 4)),
            ("tensor3", torch.randn(1, 5)),
        ]
        defaults = {
            "sparsity_level": 0.5,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }

        data_with_config = [
            {
                "name": "tensor4",
                "data": torch.randn(1, 2),
                "config": {
                    "sparsity_level": 0.7,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
            {
                "name": "tensor5",
                "data": torch.randn(4, 4),
                "config": {
                    "sparsity_level": 0.3,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
        ]
        # Exercise both supported norms against the same data.
        for norm_type in ("L1", "L2"):
            self.run_all_checks(
                data_list=data_list,
                defaults=defaults,
                data_with_config=data_with_config,
                norm_type=norm_type,
            )

    def test_nn_parameters(self):
        data_list = [
            ("param1", nn.Parameter(torch.randn(1, 8))),
            ("param2", nn.Parameter(torch.randn(4, 4))),
            ("param3", nn.Parameter(torch.randn(5, 5))),
        ]
        defaults = {
            "sparsity_level": 0.5,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }

        data_with_config = [
            {
                "name": "param4",
                "data": nn.Parameter(torch.randn(10, 10)),
                "config": {
                    "sparsity_level": 0.7,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
            {
                "name": "param5",
                "data": nn.Parameter(torch.randn(4, 4)),
                "config": {
                    "sparsity_level": 0.3,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
        ]
        for norm_type in ("L1", "L2"):
            self.run_all_checks(
                data_list=data_list,
                defaults=defaults,
                data_with_config=data_with_config,
                norm_type=norm_type,
            )

    def test_nn_embeddings(self):
        data_list = [
            ("emb1", nn.Embedding(10, 3)),
            ("emb1_bag", nn.EmbeddingBag(10, 3)),
            ("emb2", nn.Embedding(20, 3)),
            ("emb2_bag", nn.EmbeddingBag(20, 3)),
        ]
        defaults = {
            "sparsity_level": 0.5,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }

        data_with_config = [
            {
                "name": "emb3",
                "data": nn.Embedding(15, 3),
                "config": {
                    "sparsity_level": 0.7,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
            {
                "name": "emb3_bag",
                "data": nn.EmbeddingBag(20, 3),
                "config": {
                    "sparsity_level": 0.3,
                    "sparse_block_shape": (2, 3),
                    "zeros_per_block": 6,
                },
            },
        ]
        for norm_type in ("L1", "L2"):
            self.run_all_checks(
                data_list=data_list,
                defaults=defaults,
                data_with_config=data_with_config,
                norm_type=norm_type,
            )
686
687
class Model(nn.Module):
    """Toy model mixing top-level embeddings, embeddings nested in a
    container, and linear layers — used by TestQuantizationUtils to verify
    which modules get sparsified/quantized."""

    def __init__(self) -> None:
        super().__init__()
        # Standalone embedding modules.
        self.emb1 = nn.Embedding(100, 3)
        self.embbag1 = nn.EmbeddingBag(200, 32)
        # Embeddings nested inside a container module.
        nested_emb = nn.Embedding(150, 3)
        nested_bag = nn.EmbeddingBag(100, 3)
        self.emb_seq = nn.Sequential(nested_emb, nested_bag)
        # Linear layers (expected to be left untouched by the PTQ helper).
        self.linear1 = nn.Linear(32, 32)
        self.linear2 = nn.Linear(16, 16)
696
697
class TestQuantizationUtils(TestCase):
    def test_ptq_sparsify_first(self):
        """The expectation is post_training_sparse_quantize function
        1. Takes in a model
        2. Sparsifies the embeddings
        3. Quantize the embeddings

        This unit test checks that
        1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels
        2. Embeddings and EmbeddingBags are quantized
        3. Linear modules are not quantized
        """
        model = Model()

        sparse_config = {"sparsity_level": 0.80, "sparse_block_shape": (1, 1)}
        select_embeddings = [model.embbag1, model.emb1]
        post_training_sparse_quantize(
            model,
            data_sparsifier_class=DataNormSparsifier,
            sparsify_first=True,
            select_embeddings=select_embeddings,
            **sparse_config,
        )

        assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
        assert (
            type(model.embbag1)
            == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
        )
        # BUG FIX: originally written as ``assert type(a == b)`` which asserts
        # the type of a bool (always truthy). Compare the types themselves:
        # the nested embeddings were not selected, so they stay unquantized.
        assert type(model.emb_seq[0]) == nn.Embedding
        assert type(model.emb_seq[1]) == nn.EmbeddingBag
        assert type(model.linear1) == nn.Linear
        assert type(model.linear2) == nn.Linear

        dequant_emb1 = torch.dequantize(model.emb1.weight())
        dequant_embbag1 = torch.dequantize(model.embbag1.weight())

        # small tolerance around zero: quantization perturbs exact zeros
        threshold = 1e-2

        sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
        sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()

        assert abs(sl_emb1 - 0.80) <= 0.05  # +- 5% leeway
        assert abs(sl_embbag1 - 0.80) <= 0.05  # +- 5% leeway

    def test_ptq_quantize_first(self):
        """The expectation is post_training_sparse_quantize function
        1. Takes in a model
        2. Quantize the embeddings
        3. Sparsifies the embeddings

        This unit test checks that
        1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels
        2. Embeddings and EmbeddingBags are quantized
        3. Linear modules are not quantized
        """
        model = Model()

        sparse_config = {"sparsity_level": 0.8, "sparse_block_shape": (1, 1)}
        post_training_sparse_quantize(
            model, DataNormSparsifier, sparsify_first=False, **sparse_config
        )

        assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
        assert (
            type(model.embbag1)
            == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
        )
        # BUG FIX: originally ``assert type(a == b)`` (always truthy).  With no
        # select_embeddings, the nested embeddings are quantized as well.
        assert (
            type(model.emb_seq[0])
            == torch.ao.nn.quantized.modules.embedding_ops.Embedding
        )
        assert (
            type(model.emb_seq[1])
            == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
        )
        assert type(model.linear1) == nn.Linear  # not quantized
        assert type(model.linear2) == nn.Linear  # not quantized

        dequant_emb1 = torch.dequantize(model.emb1.weight())
        dequant_embbag1 = torch.dequantize(model.embbag1.weight())
        dequant_emb_seq_0 = torch.dequantize(model.emb_seq[0].weight())
        dequant_emb_seq_1 = torch.dequantize(model.emb_seq[1].weight())

        # higher threshold as quantization occurs before sparsity
        threshold = (
            1  # zero points seem to have higher magnitude with sparsity occurring after
        )

        sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
        sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()
        sl_emb_seq_0 = (torch.abs(dequant_emb_seq_0) < threshold).float().mean()
        sl_emb_seq_1 = (torch.abs(dequant_emb_seq_1) < threshold).float().mean()

        assert abs(sl_emb1 - 0.80) <= 0.05  # +- 5% leeway
        assert abs(sl_embbag1 - 0.80) <= 0.05  # +- 5% leeway
        assert abs(sl_emb_seq_0 - 0.80) <= 0.05  # +- 5% leeway
        assert abs(sl_emb_seq_1 - 0.80) <= 0.05  # +- 5% leeway
794