# Owner(s): ["module: unknown"]

import warnings

from torch import nn
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torch.testing._internal.common_utils import TestCase


class ImplementedScheduler(BaseScheduler):
    def get_sl(self):
        if self.last_epoch > 0:
            return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups]
        else:
            return list(self.base_sl)


class TestScheduler(TestCase):
    def test_constructor(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        assert scheduler.sparsifier is sparsifier
        assert scheduler._step_count == 1
        assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]]

    def test_order_of_steps(self):
        """Checks that a warning is raised if the scheduler step is called
        before the sparsifier step."""

        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: this will also trigger if other, unrelated warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_scheduler
            for warning in w:
                fname = warning.filename
                fname = "/".join(fname.split("/")[-5:])
                assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py"

    def test_step(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5
        scheduler = ImplementedScheduler(sparsifier)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5

        sparsifier.step()
        scheduler.step()
        assert sparsifier.groups[0]["sparsity_level"] == 0.25

    def test_lambda_scheduler(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5
        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        assert sparsifier.groups[0]["sparsity_level"] == 0.0  # Epoch 0
        scheduler.step()
        assert sparsifier.groups[0]["sparsity_level"] == 5.0  # Epoch 1


class TestCubicScheduler(TestCase):
    def setUp(self):
        self.model_sparse_config = [
            {"tensor_fqn": "0.weight", "sparsity_level": 0.8},
            {"tensor_fqn": "2.weight", "sparsity_level": 0.4},
        ]
        self.sorted_sparse_levels = [
            conf["sparsity_level"] for conf in self.model_sparse_config
        ]
        self.initial_sparsity = 0.1
        self.initial_step = 3

    def _make_model(self, **kwargs):
        model = nn.Sequential(
            nn.Linear(13, 17),
            nn.Dropout(0.5),
            nn.Linear(17, 3),
        )
        return model

    def _make_scheduler(self, model, **kwargs):
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=self.model_sparse_config)

        scheduler_args = {
            "init_sl": self.initial_sparsity,
            "init_t": self.initial_step,
        }
        scheduler_args.update(kwargs)

        scheduler = CubicSL(sparsifier, **scheduler_args)
        return sparsifier, scheduler

    @staticmethod
    def _get_sparsity_levels(sparsifier, precision=32):
        r"""Gets the current levels of sparsity in a sparsifier."""
        return [
            round(group["sparsity_level"], precision) for group in sparsifier.groups
        ]

    def test_constructor(self):
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
        self.assertIs(
            scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
        )
        self.assertEqual(
            scheduler._step_count,
            1,
            msg="Scheduler is initialized with incorrect step count",
        )
        self.assertEqual(
            scheduler.base_sl,
            self.sorted_sparse_levels,
            msg="Scheduler did not store the target sparsity levels correctly",
        )

        # Value before t_0 is 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )

        # Value before t_0 is s_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )

    def test_step(self):
        # For n=5, dt=2, there will be a total of 10 steps between s_0 and s_f,
        # starting from t_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(
            model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
        )

        scheduler.step()
        scheduler.step()
        self.assertEqual(
            scheduler._step_count,
            3,
            msg="Scheduler step_count is expected to increment",
        )
        # Value before t_0 is supposed to be 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Scheduler step is updating the sparsity level before t_0",
        )

        scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset to initial sparsity at the first step",
        )

        scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
        self.assertEqual(
            self._get_sparsity_levels(sparsifier, 1),
            [0.3, 0.2],
            msg="Sparsity level is not set correctly after the first step",
        )

        current_step = scheduler._step_count - scheduler.init_t[0] - 1
        more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
        for _ in range(more_steps_needed):  # More steps needed to reach the final sparsity level
            scheduler.step()
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            self.sorted_sparse_levels,
            msg="Sparsity level is not reaching the target level after delta_t * n steps",
        )
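

# Illustrative sketch, not exercised by the tests above: the expected values in
# TestCubicScheduler.test_step are consistent with the cubic sparsity schedule
# from Zhu & Gupta, "To prune, or not to prune" (2017), which CubicSL is assumed
# here to follow. The helper below is a hypothetical reference computation, not
# part of the torch.ao.pruning API.
def _cubic_sparsity_reference(s_0, s_f, t, t_0, dt, n, initially_zero=True):
    """Reference cubic schedule: s_t = s_f + (s_0 - s_f) * (1 - (t - t_0) / (dt * n)) ** 3.

    With s_0=0.1, s_f=0.8, t_0=3, dt=2, n=5 this yields ~0.29 at t=4 (rounds to
    0.3 at one decimal) and holds s_f once t >= t_0 + dt * n, matching the
    expectations asserted in test_step.
    """
    if initially_zero and t < t_0:
        return 0.0  # before t_0 the sparsity level stays at zero
    if t >= t_0 + dt * n:
        return s_f  # after dt * n steps the target sparsity is held
    progress = (t - t_0) / (dt * n)
    return s_f + (s_0 - s_f) * (1.0 - progress) ** 3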