# mypy: allow-untyped-defs
import importlib
import math
import unittest
import warnings
from typing import List

import torch
import torch.nn as nn
from torch.ao.pruning._experimental.data_scheduler.base_data_scheduler import (
    BaseDataScheduler,
)
from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import (
    SUPPORTED_TYPES,
)
from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import (
    DataNormSparsifier,
)
from torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks._data_sparstity_utils import (
    _get_valid_name,
)
from torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity import (
    PostTrainingDataSparsity,
    TrainingAwareDataSparsity,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_utils import run_tests, TestCase


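# DummyModel builds a bias-free MLP: Linear(iC, oC[0]) -> ReLU -> ... -> Linear(oC[-2], oC[-1]).
# For example, the DummyModel(100, [128, 256, 16]) used in the tests below becomes
# Linear(100, 128) -> ReLU -> Linear(128, 256) -> ReLU -> Linear(256, 16), so every
# sparsifiable parameter is a plain 2D weight tensor.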
class DummyModel(nn.Module):
    def __init__(self, iC: int, oC: List[int]):
        super().__init__()
        self.linears = nn.Sequential()
        i = iC
        for idx, c in enumerate(oC):
            self.linears.append(nn.Linear(i, c, bias=False))
            if idx < len(oC) - 1:
                self.linears.append(nn.ReLU())
            i = c


def _make_lightning_module(iC: int, oC: List[int]):
    import pytorch_lightning as pl  # type: ignore[import]

    class DummyLightningModule(pl.LightningModule):
        def __init__(self, iC: int, oC: List[int]):
            super().__init__()
            self.model = DummyModel(iC, oC)

        def forward(self):
            pass

    return DummyLightningModule(iC, oC)


class StepSLScheduler(BaseDataScheduler):
    """The sparsity param of each data group is multiplied by gamma every step_size epochs."""

    def __init__(
        self,
        data_sparsifier,
        schedule_param="sparsity_level",
        step_size=1,
        gamma=2,
        last_epoch=-1,
        verbose=False,
    ):
        self.gamma = gamma
        self.step_size = step_size
        super().__init__(data_sparsifier, schedule_param, last_epoch, verbose)

    def get_schedule_param(self):
        if not self._get_sp_called_within_step:
            warnings.warn(
                "To get the last schedule parameter computed by the scheduler, "
                "please use `get_last_param()`.",
                UserWarning,
            )
        data_groups = self.data_sparsifier.data_groups
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return {
                name: config[self.schedule_param]
                for name, config in data_groups.items()
            }

        return {
            name: config[self.schedule_param] * self.gamma
            for name, config in data_groups.items()
        }

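# A rough sketch of the schedule above (illustration only, not constructed by the
# tests): with step_size=1 and gamma=2, a data group whose sparsity_level starts
# at 0.5 is scheduled to 0.5 * 2**epoch after `epoch` calls to .step(). This is
# the same relationship that _check_on_train_epoch_end below verifies in log space.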

class TestPostTrainingCallback(TestCase):
    def _check_on_fit_end(self, pl_module, callback, sparsifier_args):
        """Makes sure that each component is working as expected while calling the
        post-training callback.
        Specifically, check the following -
            1. sparsifier config is the same as input config
            2. data sparsifier is correctly attached to the model
            3. sparsity is achieved after .step()
            4. non-sparsified values are the same as original values
        """
        callback.on_fit_end(42, pl_module)  # 42 is a dummy value

        # check sparsifier config
        for key, value in sparsifier_args.items():
            assert callback.data_sparsifier.defaults[key] == value

        # assert that the model is correctly attached to the sparsifier
        for name, param in pl_module.model.named_parameters():
            valid_name = _get_valid_name(name)
            if type(param) not in SUPPORTED_TYPES:
                assert valid_name not in callback.data_sparsifier.state
                assert valid_name not in callback.data_sparsifier.data_groups
                continue
            assert valid_name in callback.data_sparsifier.data_groups
            assert valid_name in callback.data_sparsifier.state

            mask = callback.data_sparsifier.get_mask(name=valid_name)

            # assert that some level of sparsity is achieved
            assert (1.0 - mask.float().mean()) > 0.0

            # make sure that non-zero values in data after squash mask are equal to original values
            sparsified_data = callback.data_sparsifier.get_data(
                name=valid_name, return_original=False
            )
            assert torch.all(
                sparsified_data[sparsified_data != 0] == param[sparsified_data != 0]
            )

    @unittest.skipIf(
        not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning"
    )
    def test_post_training_callback(self):
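        # These DataNormSparsifier args request roughly 50% sparsity with
        # (1, 4) blocks in which all 4 elements are zeroed. This is a reading of
        # the config for context; the check itself only asserts that some
        # sparsity is achieved and that surviving values are untouched.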
        sparsifier_args = {
            "sparsity_level": 0.5,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }
        callback = PostTrainingDataSparsity(DataNormSparsifier, sparsifier_args)
        pl_module = _make_lightning_module(100, [128, 256, 16])

        self._check_on_fit_end(pl_module, callback, sparsifier_args)


class TestTrainingAwareCallback(TestCase):
    """Class to test the in-training version of the lightning callback.
    Simulates model training and makes sure that each hook is doing what is expected.
    """

    def _check_on_train_start(
        self, pl_module, callback, sparsifier_args, scheduler_args
    ):
        """Makes sure that the data_sparsifier and data_scheduler objects are created
        correctly.
        Basically, confirms that the input args and the sparsifier/scheduler args match.
        """

        callback.on_train_start(42, pl_module)  # 42 is a dummy value

        # sparsifier and scheduler instantiated
        assert (
            callback.data_scheduler is not None and callback.data_sparsifier is not None
        )

        # data sparsifier args are correct
        for key, value in sparsifier_args.items():
            assert callback.data_sparsifier.defaults[key] == value

        # data scheduler args are correct
        for key, value in scheduler_args.items():
            assert getattr(callback.data_scheduler, key) == value

    def _simulate_update_param_model(self, pl_module):
        """This function might not be needed, as the model is copied
        during train_epoch_end(), but it is good to have in case things change in the future.
        """
        for _, param in pl_module.model.named_parameters():
            param.data = param + 1

    def _check_on_train_epoch_start(self, pl_module, callback):
        """Ensures that the sparsifier's state is correctly restored.
        The state_dict() comparison is needed. Consider the flow -

        **Epoch: 1**
            1. on_train_epoch_start(): Nothing happens (for now)
            2. on_train_epoch_end():
                a) the model is copied into the data_sparsifier
                b) .step() is called
                c) internally, the state of each layer of the model inside
                   the data sparsifier changes

        **Epoch: 2**
            1. on_train_epoch_start(): Assume nothing happens
            2. on_train_epoch_end():
                a) the model is copied into the data_sparsifier.
                   However, a config is needed to attach each layer of the
                   module to the sparsifier. If the config is None, the
                   data_sparsifier falls back to the default config, which we
                   do not want because the config of each layer changes after
                   .step()

        Hence, the state_dict() must be dumped and restored every time, because the
        model is copied in again after each epoch, and it is therefore essential to
        make sure that the sparsifier's state_dict() is correctly dumped and restored.

        """
        # check if each component of state dict is being loaded correctly
        callback.on_train_epoch_start(42, pl_module)
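        # nothing to compare on the very first epoch - no sparsifier state has
        # been dumped yet at this point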
        if callback.data_sparsifier_state_dict is None:
            return

        data_sparsifier_state_dict = callback.data_sparsifier.state_dict()

        # compare container objects
        container_obj1 = data_sparsifier_state_dict["_container"]
        container_obj2 = callback.data_sparsifier_state_dict["_container"]
        assert len(container_obj1) == len(container_obj2)
        for key, value in container_obj2.items():
            assert key in container_obj1
            assert torch.all(value == container_obj1[key])

        # compare state objects
        state_obj1 = data_sparsifier_state_dict["state"]
        state_obj2 = callback.data_sparsifier_state_dict["state"]
        assert len(state_obj1) == len(state_obj2)
        for key, value in state_obj2.items():
            assert key in state_obj1
            assert "mask" in value and "mask" in state_obj1[key]
            assert torch.all(value["mask"] == state_obj1[key]["mask"])

        # compare data_groups dict
        data_grp1 = data_sparsifier_state_dict["data_groups"]
        data_grp2 = callback.data_sparsifier_state_dict["data_groups"]
        assert len(data_grp1) == len(data_grp2)
        for key, value in data_grp2.items():
            assert key in data_grp1
            assert value == data_grp1[key]

    def _check_on_train_epoch_end(self, pl_module, callback):
        """Checks the following -
        1. sparsity is correctly achieved after .step()
        2. the scheduler and data_sparsifier sparsity levels are consistent
        """
        callback.on_train_epoch_end(42, pl_module)
        data_scheduler = callback.data_scheduler
        base_sl = data_scheduler.base_param
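        # base_param holds the initial schedule_param (here, sparsity_level) for
        # each data group, keyed by the same valid_name used below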

        for name, _ in pl_module.model.named_parameters():
            valid_name = _get_valid_name(name)
            mask = callback.data_sparsifier.get_mask(name=valid_name)

            # check sparsity levels
            assert (1.0 - mask.float().mean()) > 0  # some sparsity level achieved

            last_sl = data_scheduler.get_last_param()
            last_epoch = data_scheduler.last_epoch

            # check sparsity levels of scheduler
            log_last_sl = math.log(last_sl[valid_name])
            log_actual_sl = math.log(
                base_sl[valid_name] * (data_scheduler.gamma**last_epoch)
            )
            assert log_last_sl == log_actual_sl

    def _check_on_train_end(self, pl_module, callback):
        """Confirms that the masks are squashed after training ends.
        This is achieved by making sure that each parameter in the internal container
        is not parametrized.
        """
        callback.on_train_end(42, pl_module)

        # check that the masks have been squashed
        for name, _ in pl_module.model.named_parameters():
            valid_name = _get_valid_name(name)
            assert not is_parametrized(callback.data_sparsifier._container, valid_name)

    @unittest.skipIf(
        not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning"
    )
    def test_train_aware_callback(self):
        sparsifier_args = {
            "sparsity_level": 0.5,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }
        scheduler_args = {"gamma": 2, "step_size": 1}

        callback = TrainingAwareDataSparsity(
            data_sparsifier_class=DataNormSparsifier,
            data_sparsifier_args=sparsifier_args,
            data_scheduler_class=StepSLScheduler,
            data_scheduler_args=scheduler_args,
        )

        pl_module = _make_lightning_module(100, [128, 256, 16])

        # simulate the training process and check all steps
        self._check_on_train_start(pl_module, callback, sparsifier_args, scheduler_args)

        num_epochs = 5
        for _ in range(num_epochs):
            self._check_on_train_epoch_start(pl_module, callback)
            self._simulate_update_param_model(pl_module)
            self._check_on_train_epoch_end(pl_module, callback)

        # finally, make sure the masks are squashed when training ends
        self._check_on_train_end(pl_module, callback)


if __name__ == "__main__":
    run_tests()