1# mypy: allow-untyped-defs 2import importlib 3import math 4import unittest 5import warnings 6from typing import List 7 8import torch 9import torch.nn as nn 10from torch.ao.pruning._experimental.data_scheduler.base_data_scheduler import ( 11 BaseDataScheduler, 12) 13from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import ( 14 SUPPORTED_TYPES, 15) 16from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import ( 17 DataNormSparsifier, 18) 19from torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks._data_sparstity_utils import ( 20 _get_valid_name, 21) 22from torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity import ( 23 PostTrainingDataSparsity, 24 TrainingAwareDataSparsity, 25) 26from torch.nn.utils.parametrize import is_parametrized 27from torch.testing._internal.common_utils import run_tests, TestCase 28 29 30class DummyModel(nn.Module): 31 def __init__(self, iC: int, oC: List[int]): 32 super().__init__() 33 self.linears = nn.Sequential() 34 i = iC 35 for idx, c in enumerate(oC): 36 self.linears.append(nn.Linear(i, c, bias=False)) 37 if idx < len(oC) - 1: 38 self.linears.append(nn.ReLU()) 39 i = c 40 41 42def _make_lightning_module(iC: int, oC: List[int]): 43 import pytorch_lightning as pl # type: ignore[import] 44 45 class DummyLightningModule(pl.LightningModule): 46 def __init__(self, ic: int, oC: List[int]): 47 super().__init__() 48 self.model = DummyModel(iC, oC) 49 50 def forward(self): 51 pass 52 53 return DummyLightningModule(iC, oC) 54 55 56class StepSLScheduler(BaseDataScheduler): 57 """The sparsity param of each data group is multiplied by gamma every step_size epochs.""" 58 59 def __init__( 60 self, 61 data_sparsifier, 62 schedule_param="sparsity_level", 63 step_size=1, 64 gamma=2, 65 last_epoch=-1, 66 verbose=False, 67 ): 68 self.gamma = gamma 69 self.step_size = step_size 70 super().__init__(data_sparsifier, schedule_param, last_epoch, verbose) 71 72 def get_schedule_param(self): 73 if not self._get_sp_called_within_step: 74 warnings.warn( 75 "To get the last learning rate computed by the scheduler, " 76 "please use `get_last_lr()`.", 77 UserWarning, 78 ) 79 data_groups = self.data_sparsifier.data_groups 80 if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): 81 return { 82 name: config[self.schedule_param] 83 for name, config in data_groups.items() 84 } 85 86 return { 87 name: config[self.schedule_param] * self.gamma 88 for name, config in data_groups.items() 89 } 90 91 92class TestPostTrainingCallback(TestCase): 93 def _check_on_fit_end(self, pl_module, callback, sparsifier_args): 94 """Makes sure that each component of is working as expected while calling the 95 post-training callback. 96 Specifically, check the following - 97 1. sparsifier config is the same as input config 98 2. data sparsifier is correctly attached to the model 99 3. sparsity is achieved after .step() 100 4. non-sparsified values are the same as original values 101 """ 102 callback.on_fit_end(42, pl_module) # 42 is a dummy value 103 104 # check sparsifier config 105 for key, value in sparsifier_args.items(): 106 assert callback.data_sparsifier.defaults[key] == value 107 108 # assert that the model is correctly attached to the sparsifier 109 for name, param in pl_module.model.named_parameters(): 110 valid_name = _get_valid_name(name) 111 if type(param) not in SUPPORTED_TYPES: 112 assert valid_name not in callback.data_sparsifier.state 113 assert valid_name not in callback.data_sparsifier.data_groups 114 continue 115 assert valid_name in callback.data_sparsifier.data_groups 116 assert valid_name in callback.data_sparsifier.state 117 118 mask = callback.data_sparsifier.get_mask(name=valid_name) 119 120 # assert that some level of sparsity is achieved 121 assert (1.0 - mask.float().mean()) > 0.0 122 123 # make sure that non-zero values in data after squash mask are equal to original values 124 sparsified_data = callback.data_sparsifier.get_data( 125 name=valid_name, return_original=False 126 ) 127 assert torch.all( 128 sparsified_data[sparsified_data != 0] == param[sparsified_data != 0] 129 ) 130 131 @unittest.skipIf( 132 not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning" 133 ) 134 def test_post_training_callback(self): 135 sparsifier_args = { 136 "sparsity_level": 0.5, 137 "sparse_block_shape": (1, 4), 138 "zeros_per_block": 4, 139 } 140 callback = PostTrainingDataSparsity(DataNormSparsifier, sparsifier_args) 141 pl_module = _make_lightning_module(100, [128, 256, 16]) 142 143 self._check_on_fit_end(pl_module, callback, sparsifier_args) 144 145 146class TestTrainingAwareCallback(TestCase): 147 """Class to test in-training version of lightning callback 148 Simulates model training and makes sure that each hook is doing what is expected 149 """ 150 151 def _check_on_train_start( 152 self, pl_module, callback, sparsifier_args, scheduler_args 153 ): 154 """Makes sure that the data_sparsifier and data_scheduler objects are being created 155 correctly. 156 Basically, confirms that the input args and sparsifier/scheduler args are in-line. 157 """ 158 159 callback.on_train_start(42, pl_module) # 42 is a dummy value 160 161 # sparsifier and scheduler instantiated 162 assert ( 163 callback.data_scheduler is not None and callback.data_sparsifier is not None 164 ) 165 166 # data sparsifier args are correct 167 for key, value in sparsifier_args.items(): 168 assert callback.data_sparsifier.defaults[key] == value 169 170 # data scheduler args are correct 171 for key, value in scheduler_args.items(): 172 assert getattr(callback.data_scheduler, key) == value 173 174 def _simulate_update_param_model(self, pl_module): 175 """This function might not be needed as the model is being copied 176 during train_epoch_end() but good to have if things change in the future 177 """ 178 for _, param in pl_module.model.named_parameters(): 179 param.data = param + 1 180 181 def _check_on_train_epoch_start(self, pl_module, callback): 182 """Basically ensures that the sparsifier's state is correctly being restored. 183 The state_dict() comparison is needed. Consider the flow - 184 185 **Epoch: 1** 186 1. on_train_epoch_start(): Nothing happens (for now) 187 2. on_train_epoch_end(): 188 a) the model is copied into the data_sparsifier 189 b) .step() is called 190 c) internally, the state of each layer of the model inside 191 data sparsifier changes 192 193 **Epoch: 2** 194 1. on_train_epoch_start(): Assume nothing happens 195 2. on_train_epoch_end(): 196 a) the model is copied into the data_sparsifier. 197 But wait! you need the config to attach layer 198 of the module to the sparsifier. If config is None, 199 the data_sparsifier uses the default config which we 200 do not want as the config of each layer changes after 201 .step() 202 203 Hence, we need to dump and restore the state_dict() everytime because we're 204 copying the model after each epoch. 205 Hence, it is essential to make sure that the sparsifier's state_dict() is being 206 correctly dumped and restored. 207 208 """ 209 # check if each component of state dict is being loaded correctly 210 callback.on_train_epoch_start(42, pl_module) 211 if callback.data_sparsifier_state_dict is None: 212 return 213 214 data_sparsifier_state_dict = callback.data_sparsifier.state_dict() 215 216 # compare container objects 217 container_obj1 = data_sparsifier_state_dict["_container"] 218 container_obj2 = callback.data_sparsifier_state_dict["_container"] 219 assert len(container_obj1) == len(container_obj2) 220 for key, value in container_obj2.items(): 221 assert key in container_obj1 222 assert torch.all(value == container_obj1[key]) 223 224 # compare state objects 225 state_obj1 = data_sparsifier_state_dict["state"] 226 state_obj2 = callback.data_sparsifier_state_dict["state"] 227 assert len(state_obj1) == len(state_obj2) 228 for key, value in state_obj2.items(): 229 assert key in state_obj1 230 assert "mask" in value and "mask" in state_obj1[key] 231 assert torch.all(value["mask"] == state_obj1[key]["mask"]) 232 233 # compare data_groups dict 234 data_grp1 = data_sparsifier_state_dict["data_groups"] 235 data_grp2 = callback.data_sparsifier_state_dict["data_groups"] 236 assert len(data_grp1) == len(data_grp2) 237 for key, value in data_grp2.items(): 238 assert key in data_grp1 239 assert value == data_grp1[key] 240 241 def _check_on_train_epoch_end(self, pl_module, callback): 242 """Checks the following - 243 1. sparsity is correctly being achieved after .step() 244 2. scheduler and data_sparsifier sparsity levels are in-line 245 """ 246 callback.on_train_epoch_end(42, pl_module) 247 data_scheduler = callback.data_scheduler 248 base_sl = data_scheduler.base_param 249 250 for name, _ in pl_module.model.named_parameters(): 251 valid_name = _get_valid_name(name) 252 mask = callback.data_sparsifier.get_mask(name=valid_name) 253 254 # check sparsity levels 255 assert (1.0 - mask.float().mean()) > 0 # some sparsity level achieved 256 257 last_sl = data_scheduler.get_last_param() 258 last_epoch = data_scheduler.last_epoch 259 260 # check sparsity levels of scheduler 261 log_last_sl = math.log(last_sl[valid_name]) 262 log_actual_sl = math.log( 263 base_sl[valid_name] * (data_scheduler.gamma**last_epoch) 264 ) 265 assert log_last_sl == log_actual_sl 266 267 def _check_on_train_end(self, pl_module, callback): 268 """Confirms that the mask is squashed after the training ends 269 This is achieved by making sure that each parameter in the internal container 270 are not parametrized. 271 """ 272 callback.on_train_end(42, pl_module) 273 274 # check that the masks have been squashed 275 for name, _ in pl_module.model.named_parameters(): 276 valid_name = _get_valid_name(name) 277 assert not is_parametrized(callback.data_sparsifier._continer, valid_name) 278 279 @unittest.skipIf( 280 not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning" 281 ) 282 def test_train_aware_callback(self): 283 sparsifier_args = { 284 "sparsity_level": 0.5, 285 "sparse_block_shape": (1, 4), 286 "zeros_per_block": 4, 287 } 288 scheduler_args = {"gamma": 2, "step_size": 1} 289 290 callback = TrainingAwareDataSparsity( 291 data_sparsifier_class=DataNormSparsifier, 292 data_sparsifier_args=sparsifier_args, 293 data_scheduler_class=StepSLScheduler, 294 data_scheduler_args=scheduler_args, 295 ) 296 297 pl_module = _make_lightning_module(100, [128, 256, 16]) 298 299 # simulate the training process and check all steps 300 self._check_on_train_start(pl_module, callback, sparsifier_args, scheduler_args) 301 302 num_epochs = 5 303 for _ in range(0, num_epochs): 304 self._check_on_train_epoch_start(pl_module, callback) 305 self._simulate_update_param_model(pl_module) 306 self._check_on_train_epoch_end(pl_module, callback) 307 308 309if __name__ == "__main__": 310 run_tests() 311