# Owner(s): ["module: unknown"]
import copy
import logging
import random

import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
    BaseStructuredSparsifier,
    FakeStructuredSparsity,
    FPGMPruner,
    LSTMSaliencyPruner,
    SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
    Conv2dActivation,
    Conv2dBias,
    Conv2dPadBias,
    Conv2dPool,
    Conv2dPoolFlatten,
    Conv2dPoolFlattenFunctional,
    LinearActivation,
    LinearActivationFunctional,
    LinearBias,
    LSTMLayerNormLinearModel,
    LSTMLinearModel,
    rows_are_subset,
    SimpleConv2d,
    SimpleLinear,
)
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

DEVICES = {
    torch.device("cpu"),
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
}
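# Because DEVICES is a set, the two entries collapse to just {cpu} on machines
# without CUDA, so each test body runs once per unique available device.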


class SimplePruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
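        # FakeStructuredSparsity stores a boolean mask over output channels;
        # rows set to False are physically removed when prune() is called.
        # Deterministically mark the second output channel for pruning.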
        getattr(module.parametrizations, tensor_name)[0].mask[1] = False


class ImplementedPruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        """Prune one third of the output channels, so the resulting module is ~33% pruned."""
        num_rows = len(module.parametrizations[tensor_name][0].mask)
        prune = random.sample(list(range(num_rows)), num_rows // 3)
        module.parametrizations[tensor_name][0].mask[prune] = False


class BottomHalfLSTMPruner(BaseStructuredSparsifier):
    """
    Pruner that removes the bottom half of the rows.
    This is primarily meant for testing purposes.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        for p in getattr(module.parametrizations, tensor_name):
            if isinstance(p, FakeStructuredSparsity):
                mask = p.mask
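                # nn.LSTM stacks the four gate weights (i, f, g, o) along
                # dim 0, so split the mask into four equal gate blocks and
                # zero out the bottom half of each block independently.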
                masks = torch.split(mask, len(mask) // 4)
                for small in masks:
                    num = len(small)
                    small[num // 2 :] = False
                new_mask = torch.cat(masks)
                mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
    def test_saliency_pruner_update_mask(self):
        """Test that we prune out the rows with the lowest saliency (the first two rows)"""
        model = SimpleLinear()
        with torch.no_grad():
            model.linear1.weight = nn.Parameter(
                torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
            )
        pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
        pruner = SaliencyPruner({})

        pruner.prepare(model, pruning_config)
        pruner.enable_mask_update = True
        pruner.step()
        pruned_model = pruner.prune()

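        # With sparsity_level=0.5 over four rows, the two rows with the
        # smallest saliency ([1, 1, 1, 1] and [2, 2, 2, 2]) are removed,
        # leaving only the two largest rows.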
        expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
        pruned = pruned_model.linear1.weight

        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

    def test_lstm_saliency_pruner_update_mask(self):
        model = LSTMLinearModel(
            input_dim=2,
            hidden_dim=2,
            output_dim=2,
            num_layers=1,
        )

        manual_weights = torch.Tensor(
            [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
        )
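        # The eight rows form four gates of two rows each; within every gate
        # the ±2 row has the larger saliency, so at sparsity_level=0.5 the
        # LSTMSaliencyPruner should keep it and drop the ±1 row.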

        with torch.no_grad():
            model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
            model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
            model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
            model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]
        lstm_input = torch.ones((1, 2))
        fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        pruned_model = fx_pruner.prune()
        pruned_model.eval()

        # make sure both models run
        model(lstm_input)
        pruned_model(lstm_input)

        # make sure lowest saliency rows are pruned
        expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
        pruned = model.lstm.weight_ih_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([[2], [2], [-2], [-2]])
        pruned = model.lstm.weight_hh_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([2, 2, -2, -2])
        for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
            assert expected.shape == pruned.shape
            assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()


class TestBaseStructuredSparsifier(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
        for config in pruner.groups:
            module = config["module"]
            assert module.weight.device.type == device.type
            # Check mask exists
            assert config["tensor_fqn"] in pruner.state
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(module)
            assert hasattr(module, "parametrizations")
            # Assume that this is the 1st/only parametrization
            assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity

    def _check_pruner_valid_before_step(self, model, pruner, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                assert module.parametrizations.weight[0].mask.dtype == torch.bool

    def _check_pruner_valid_after_step(self, model, pruner, mask, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                total = module.parametrizations.weight[0].mask.numel()
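                # Here `mask` is the expected number of pruned rows per module,
                # so the nonzero count should drop by exactly that amount.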
                assert (
                    module.parametrizations.weight[0].mask.count_nonzero()
                    == total - mask
                )

    def _test_constructor_on_device(self, model, device):
        self.assertRaisesRegex(
            TypeError,
            "BaseStructuredSparsifier.*update_mask",
            BaseStructuredSparsifier,
        )
        model1 = copy.deepcopy(model).to(device)
        pruner = SimplePruner(None)
        pruner.prepare(model1, None)
        pruner.enable_mask_update = True
        for g in pruner.groups:
            module = g["module"]
            assert module.weight.device.type == device.type
        assert len(pruner.groups) == 5
        pruner.step()
        # Can instantiate the model with configs
        model2 = copy.deepcopy(model).to(device)
        pruner = SimplePruner({"test": 3})
        pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
        assert len(pruner.groups) == 1
        assert pruner.groups[0]["module_fqn"] == "seq.0"
        assert "test" in pruner.groups[0]
        assert pruner.groups[0]["test"] == 3

    def test_constructor(self):
        model = SimpleLinear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        self._check_pruner_prepared(model, pruner, device)
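        # prepare() only attaches parametrizations (masks start out all True),
        # so the forward pass still runs and keeps its original output shape.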
        assert model(x).shape == (128, 10)

    def test_prepare_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == expected_shape

    def test_prepare_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                model = model.to(device)
                self._test_prepare_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _test_step_linear_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones(7, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)

    def test_step_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]
        for device in DEVICES:
            for model in models:
                self._test_step_linear_on_device(model, torch.device(device))

    def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)
        assert model(x).shape == expected_shape

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_step_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                self._test_step_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _check_pruner_pruned(self, model, pruner, device):
        for config in pruner.groups:
            module = config["module"]
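            # prune() should have removed the parametrizations entirely,
            # leaving plain, already-resized weights behind.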
            assert not hasattr(module, "parametrizations")
            assert not hasattr(module, "mask")

    def _test_linear_on_device(
        self, model, config, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        model.eval()
        num_original_params = sum(p.numel() for p in model.parameters())
        x = torch.ones(128, 7, device=device)

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)

        assert y_expected.shape == (128, 10)
        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
            assert num_pruned_params < num_original_params
    def test_prune_linear_linear(self):
        r"""Test pruning linear -> linear modules."""
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "linear1.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))
        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        SimpleLinear(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_bias_linear(self):
        # linear(bias) -> linear(no bias)
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.3.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(no bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        LinearBias(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_activation_linear(self):
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.2.weight"},
            {"tensor_fqn": "seq.4.weight"},
            {"tensor_fqn": "linear1.weight"},
        ]
        shape = (128, 10)

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                # test version with nn.Modules
                self._test_linear_on_device(
                    LinearActivation(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                # test functional version
                self._test_linear_on_device(
                    LinearActivationFunctional(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def _test_conv2d_on_device(
        self, model, config, x, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        num_original_params = sum(p.numel() for p in model.parameters())
        model.eval()

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)
        assert y_expected.shape == expected_shape

        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            # TODO This rtol is a little high, need to double check if something specific is causing this to fail
            assert torch.isclose(
                y_expected,
                y_pruned,
                rtol=1e-3,
                atol=1e-3,
            ).all(), f"fail for {type(model)}"
            # The parameter counts are equal only when every layer has padding,
            # in which case no channels can actually be removed.
            assert num_pruned_params <= num_original_params

    def test_prune_conv2d_conv2d(self):
        configs, shapes = [], []
        # all within sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))
        # prune across sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        SimpleConv2d(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_bias_conv2d(self):
        # Conv2d with Bias and no Activation
        configs, shapes = [], []
        # conv2d(bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_activation_conv2d(self):
        # Conv2d with Activation and no Bias
        configs, shapes = [], []

        # conv2d(no bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dActivation(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_padding_conv2d(self):
        # Conv2d with Padded layers after Bias layers
        configs, shapes = [], []

        # conv(padded, bias) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(no bias, no pad) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(padded, bias) -> conv(no bias, no pad)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(pad, bias) -> conv(no pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.6.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(no pad, bias) -> conv(pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.8.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dPadBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_pool_conv2d(self):
        # Conv2d with Pooling layers
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 52, 3, 3)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPool(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_complex_conv2d(self):
        """Test fusion for models that contain Conv2d & Linear modules.
        Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 13)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPoolFlattenFunctional(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                self._test_conv2d_on_device(
                    Conv2dPoolFlatten(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def test_prune_lstm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()
        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of
        # the dense output, since BottomHalfLSTMPruner keeps the top half.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the linear output has the same shape; this means
        # we've resized the linear input columns correctly.
        assert out_expected.shape == out_pruned.shape

    def test_prune_lstm_layernorm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> LayerNorm -> Linear
        """
        model = LSTMLayerNormLinearModel(
            input_dim=8,
            output_dim=8,
            hidden_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_layernorm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of
        # the dense output, since BottomHalfLSTMPruner keeps the top half.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the linear output has the same shape; this means
        # we've resized the linear input columns correctly.
        assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
    """
    Test case for the implementation of paper:
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    """

    class SimpleConvFPGM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv2d1 = nn.Conv2d(
                in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
            )
            # Manually set the filter weights for demonstration purposes
            """
            The three filters' weights are manually set to 3.0, 2.0, and 0.1.
            Unlike a norm-based criterion, which would prune the filter with
            value 0.1, FPGM prunes the filter with value 2.0.
            """
            weights = torch.tensor([3.0, 2.0, 0.1])  # weight value for each filter
            weights = weights[:, None, None, None]  # broadcasting
            self.conv2d1.weight.data.copy_(
                torch.ones(self.conv2d1.weight.shape) * weights
            )

            # Second Convolutional Layer
            self.conv2d2 = nn.Conv2d(
                in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
            )
            weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
            weights = weights[:, None, None, None]
            self.conv2d2.weight.data.copy_(
                torch.ones(self.conv2d2.weight.shape) * weights
            )

        def forward(self, x):
            x = self.conv2d1(x)
            x = self.conv2d2(x)
            return x

    def test_compute_distance(self, device="cpu"):
        """Test the distance computation function"""
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        pruner = FPGMPruner(0.3)
        dist_conv1 = pruner._compute_distance(model.conv2d1.weight)

        # compute the distance matrix using torch.cdist
        # each row is one 3x3 filter flattened to 9 identical values
        flattened_filters = torch.Tensor(
            [
                [3.0] * 9,
                [2.0] * 9,
                [0.1] * 9,
            ]
        )

        """
        Expected distance matrix should have the following values:
            [0.0000, 3.0000, 8.7000],
            [3.0000, 0.0000, 5.7000],
            [8.7000, 5.7000, 0.0000],
        so the per-filter distance sums should be:
            [11.7000, 8.7000, 14.4000]
        """
        expected_dist_matrix_conv1 = torch.cdist(
            flattened_filters, flattened_filters, p=2
        )
        expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
        assert torch.isclose(
            dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_single_layer(self, expected_conv1, device):
        """Test that pruning is based on pairwise distances rather than absolute norm values"""
        # test pruning with one layer of conv2d
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [{"tensor_fqn": "conv2d1.weight"}]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        assert (
            pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
            is not False
        ), "do not prune the least-norm filter"

        # pruning step
        pruned_model = pruner.prune()

        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        assert pruned_y.shape == (1, 4, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == (
            4,
            2,
            3,
            3,
        ), "conv2d2 should have input channel pruned"
        # assert value
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_multiple_layer(
        self, expected_conv1, expected_conv2, device
    ):
        # second setting: also prune the second conv layer
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
        ]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        # Get the masks for the two least-norm filters
        mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
        # Check that at least one of the two least-norm filters is not pruned
        assert (
            mask1.item() is not False or mask2.item() is not False
        ), "Do not prune all least-norm filters"

        # pruning step
        pruned_model = pruner.prune()
        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        expected_conv2 = expected_conv2.to(device)
        assert pruned_y.shape == (1, 2, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
        # assert values
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()
        assert torch.isclose(
            pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
        ).all()

    def test_update_mask(self):
        weights = torch.tensor([3.0, 0.1])
        expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]
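        # FPGM prunes the filter nearest the geometric median: for conv2d1 the
        # 2.0 filter has the smallest summed distance to the others (8.7), so
        # the 3.0 and 0.1 filters survive.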

        weights = torch.tensor([7.0, 0.4])
        expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]
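        # For conv2d2 at sparsity_level=0.5, the filters with values 6.0 and
        # 0.5 have the smallest summed distances to the other filters, so they
        # are pruned and the 7.0 and 0.4 filters remain.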

        for device in DEVICES:
            self._test_update_mask_on_single_layer(expected_conv1, device)
            self._test_update_mask_on_multiple_layer(
                expected_conv1, expected_conv2, device
            )