xref: /aosp_15_r20/external/pytorch/test/ao/sparsity/test_scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
import inspect
import warnings

from torch import nn
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torch.testing._internal.common_utils import TestCase
8
9
class ImplementedScheduler(BaseScheduler):
    r"""Minimal concrete scheduler used by the tests in this file.

    For ``last_epoch <= 0`` it reports a copy of ``base_sl``; for any later
    epoch it reports half of each group's current ``sparsity_level``.
    """

    def get_sl(self):
        if self.last_epoch <= 0:
            # Hand back a fresh list so callers never alias ``self.base_sl``.
            return [*self.base_sl]
        current = (group["sparsity_level"] for group in self.sparsifier.groups)
        return [level * 0.5 for level in current]
16
17
class TestScheduler(TestCase):
    r"""Tests for the ``BaseScheduler`` machinery, exercised through
    ``ImplementedScheduler`` and ``LambdaSL``."""

    @staticmethod
    def _make_sparsifier():
        r"""Returns a ``WeightNormSparsifier`` prepared (with ``config=None``)
        on a fresh single-``Linear`` model."""
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        return sparsifier

    def test_constructor(self):
        sparsifier = self._make_sparsifier()
        scheduler = ImplementedScheduler(sparsifier)

        # unittest assertions, unlike bare ``assert``, survive ``python -O``.
        self.assertIs(scheduler.sparsifier, sparsifier)
        self.assertEqual(scheduler._step_count, 1)
        self.assertEqual(
            scheduler.base_sl, [sparsifier.groups[0]["sparsity_level"]]
        )

    def test_order_of_steps(self):
        """Checks if the warning is thrown if the scheduler step is called
        before the sparsifier step, and that calling them in the correct
        order produces no warning from the base scheduler."""
        # Wrong order: scheduler.step() without a preceding sparsifier.step().
        sparsifier = self._make_sparsifier()
        scheduler = ImplementedScheduler(sparsifier)
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order. A fresh sparsifier/scheduler pair is used so this
        # check does not depend on state left behind by the wrong-order call.
        sparsifier = self._make_sparsifier()
        scheduler = ImplementedScheduler(sparsifier)

        # Identify the base scheduler's source file from the code itself
        # rather than from a hard-coded path such as
        # "torch/ao/.../base_scheduler.py", which silently stops matching
        # (and so can never fail) once the module is moved or renamed.
        # NOTE(review): assumes the ordering warning is issued from the file
        # that defines ``BaseScheduler.step``.
        base_scheduler_file = inspect.getfile(inspect.unwrap(BaseScheduler.step))
        with warnings.catch_warnings(record=True) as caught:
            # "always" so a warning is recorded even if the same source
            # location already warned earlier in this process.
            warnings.simplefilter("always")
            sparsifier.step()
            scheduler.step()

        base_scheduler_warnings = [
            str(w.message) for w in caught if w.filename == base_scheduler_file
        ]
        self.assertFalse(
            base_scheduler_warnings,
            msg=f"Unexpected base scheduler warnings: {base_scheduler_warnings}",
        )

    def test_step(self):
        sparsifier = self._make_sparsifier()
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)
        scheduler = ImplementedScheduler(sparsifier)
        # At epoch 0 ImplementedScheduler reports base_sl, so nothing changes.
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)

        sparsifier.step()
        scheduler.step()
        # From epoch 1 on, ImplementedScheduler halves the current level.
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.25)

    def test_lambda_scheduler(self):
        sparsifier = self._make_sparsifier()
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)
        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        # Expected values follow 0.5 * (epoch * 10).
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.0)  # Epoch 0
        scheduler.step()
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 5.0)  # Epoch 1
73        assert sparsifier.groups[0]["sparsity_level"] == 5.0  # Epoch 1
74
75
76class TestCubicScheduler(TestCase):
77    def setUp(self):
78        self.model_sparse_config = [
79            {"tensor_fqn": "0.weight", "sparsity_level": 0.8},
80            {"tensor_fqn": "2.weight", "sparsity_level": 0.4},
81        ]
82        self.sorted_sparse_levels = [
83            conf["sparsity_level"] for conf in self.model_sparse_config
84        ]
85        self.initial_sparsity = 0.1
86        self.initial_step = 3
87
88    def _make_model(self, **kwargs):
89        model = nn.Sequential(
90            nn.Linear(13, 17),
91            nn.Dropout(0.5),
92            nn.Linear(17, 3),
93        )
94        return model
95
96    def _make_scheduler(self, model, **kwargs):
97        sparsifier = WeightNormSparsifier()
98        sparsifier.prepare(model, config=self.model_sparse_config)
99
100        scheduler_args = {
101            "init_sl": self.initial_sparsity,
102            "init_t": self.initial_step,
103        }
104        scheduler_args.update(kwargs)
105
106        scheduler = CubicSL(sparsifier, **scheduler_args)
107        return sparsifier, scheduler
108
109    @staticmethod
110    def _get_sparsity_levels(sparsifier, precision=32):
111        r"""Gets the current levels of sparsity in a sparsifier."""
112        return [
113            round(group["sparsity_level"], precision) for group in sparsifier.groups
114        ]
115
116    def test_constructor(self):
117        model = self._make_model()
118        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
119        self.assertIs(
120            scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
121        )
122        self.assertEqual(
123            scheduler._step_count,
124            1,
125            msg="Scheduler is initialized with incorrect step count",
126        )
127        self.assertEqual(
128            scheduler.base_sl,
129            self.sorted_sparse_levels,
130            msg="Scheduler did not store the target sparsity levels correctly",
131        )
132
133        # Value before t_0 is 0
134        self.assertEqual(
135            self._get_sparsity_levels(sparsifier),
136            scheduler._make_sure_a_list(0.0),
137            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
138        )
139
140        # Value before t_0 is s_0
141        model = self._make_model()
142        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
143        self.assertEqual(
144            self._get_sparsity_levels(sparsifier),
145            scheduler._make_sure_a_list(self.initial_sparsity),
146            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
147        )
148
149    def test_step(self):
150        # For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
151        model = self._make_model()
152        sparsifier, scheduler = self._make_scheduler(
153            model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
154        )
155
156        scheduler.step()
157        scheduler.step()
158        self.assertEqual(
159            scheduler._step_count,
160            3,
161            msg="Scheduler step_count is expected to increment",
162        )
163        # Value before t_0 is supposed to be 0
164        self.assertEqual(
165            self._get_sparsity_levels(sparsifier),
166            scheduler._make_sure_a_list(0.0),
167            msg="Scheduler step updating the sparsity level before t_0",
168        )
169
170        scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
171        self.assertEqual(
172            self._get_sparsity_levels(sparsifier),
173            scheduler._make_sure_a_list(self.initial_sparsity),
174            msg="Sparsifier is not reset to initial sparsity at the first step",
175        )
176
177        scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
178        self.assertEqual(
179            self._get_sparsity_levels(sparsifier, 1),
180            [0.3, 0.2],
181            msg="Sparsity level is not set correctly after the first step",
182        )
183
184        current_step = scheduler._step_count - scheduler.init_t[0] - 1
185        more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
186        for _ in range(more_steps_needed):  # More steps needed to final sparsity level
187            scheduler.step()
188        self.assertEqual(
189            self._get_sparsity_levels(sparsifier),
190            self.sorted_sparse_levels,
191            msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
192        )
193