# mypy: allow-untyped-defs
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, Optional, TYPE_CHECKING

import pytorch_lightning as pl  # type: ignore[import]

from ._data_sparstity_utils import (
    _attach_model_to_data_sparsifier,
    _get_valid_name,
    _log_sparsified_level,
)


if TYPE_CHECKING:
    import torch


class PostTrainingDataSparsity(pl.callbacks.Callback):
    """Lightning callback that enables post-training sparsity.

    This callback aims to sparsify the model inside lightning module after training.
    **Note that the model is copied and then sparsified, so the existing model is not modified**

    The sparsified model can be used for comparison and can be accessed using
    <callback_obj>.sparsified

    Args:
        data_sparsifier_class (some implemented class of BaseDataSparsifier)
            The data sparsifier object of this class is created when the
            training starts.
            Note: Objects should not be passed in here as they are created
            once the training completes.

        data_sparsifier_args (Dict)
            Dictionary of args to be passed to the data sparsifier.
            Note: data_list arg should be ignored

    Hooks implemented:
        on_fit_end()
            1. copies the model and attaches it to the sparsifier
            2. sparsifier step() is called
            3. squashes the mask()
    """

    def __init__(self, data_sparsifier_class, data_sparsifier_args):
        super().__init__()
        self.data_sparsifier_class = data_sparsifier_class
        self.data_sparsifier_args = data_sparsifier_args
        # Both are created lazily in on_fit_end(); None until training completes.
        self.data_sparsifier: Any = None
        self.sparsified: Optional[torch.nn.Module] = None

    def on_fit_end(self, trainer, pl_module) -> None:
        """Copy the trained model, sparsify it in one shot, and log sparsity levels."""
        # Deep-copy so pl_module.model itself is never mutated; eval() since the
        # copy is only used for post-training comparison, not further training.
        self.sparsified = deepcopy(pl_module.model).eval()
        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)

        _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier)

        self.data_sparsifier.step()

        self.data_sparsifier.squash_mask()  # currently squashes params for all mask

        _log_sparsified_level(self.sparsified, self.data_sparsifier)


class TrainingAwareDataSparsity(pl.callbacks.Callback):
    """Lightning callback that enables in-training sparsity.

    This callback aims to sparsify the model inside lightning module during training.
    **Note that the model is copied and then sparsified, so the existing model is not modified**

    The sparsified model can be used for comparison and can be accessed using
    <callback_obj>.sparsified

    Args:
        data_sparsifier_class (some implemented class of BaseDataSparsifier)
            The data sparsifier object of this class is created when the
            training starts.
            Note: Objects should not be passed in here as they are created
            when the training starts.

        data_sparsifier_args (Dict)
            Dictionary of args to be passed to the data sparsifier.
            Note: data_list arg should be ignored

        data_scheduler_class (some implemented class of BaseDataScheduler)
            The data scheduler of this class is created when the training starts
            Note: Objects should not be passed in here as they are created
            when the training starts.

        data_scheduler_args(Dict)
            Dictionary of args to be passed to the data scheduler.
            **Note: data_sparsifier arg should be ignored as the recipe
            creates and pass sparsifier object into the class**

    Hooks implemented:
        on_train_start()
            Data sparsifier and scheduler objects are created.
            Pytorch model attached to the sparsifier

        on_train_epoch_start()
            Loads the state_dict of the data sparsifier

        on_train_epoch_end()
            1. Copies the model and attaches it to the sparsifier
            2. sparsifier step() and scheduler step()
            3. Dump state_dict of the current sparsifier

        on_train_end()
            squash mask
    """

    def __init__(
        self,
        data_sparsifier_class,
        data_sparsifier_args,
        data_scheduler_class,
        data_scheduler_args,
    ):
        super().__init__()
        # data sparsifier objects
        self.data_sparsifier_class = data_sparsifier_class
        self.data_sparsifier_args = data_sparsifier_args

        # scheduler objects
        self.data_scheduler_class = data_scheduler_class
        self.data_scheduler_args = data_scheduler_args

        # fields, populated once training starts
        self.data_sparsifier: Any = None
        self.data_scheduler: Any = None
        self.sparsified: Optional[torch.nn.Module] = None

        # Snapshot taken at each epoch end and restored at the next epoch start,
        # so sparsifier state survives the per-epoch model re-copy.
        self.data_sparsifier_state_dict: Any = None

    def on_train_start(self, trainer, pl_module) -> None:
        """Instantiate the sparsifier and scheduler and attach a model copy."""
        # create sparsifier
        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)
        self.sparsified = deepcopy(pl_module.model)

        _attach_model_to_data_sparsifier(
            self.sparsified, self.data_sparsifier
        )  # just to populate the base_sl in the scheduler

        # create scheduler; copy args so the caller's dict is not mutated
        args = deepcopy(self.data_scheduler_args)
        args["data_sparsifier"] = self.data_sparsifier
        self.data_scheduler = self.data_scheduler_class(**args)

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Restore the sparsifier state saved at the end of the previous epoch."""
        if self.data_sparsifier_state_dict is None:
            return  # probably first epoch

        # load the existing config for each data
        self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict)

    def __create_config_based_on_state(self, pl_module) -> Dict:
        """Rebuild the per-parameter sparsifier config from the saved state.

        Returns an empty dict on the first epoch (no saved state yet); otherwise
        maps each valid parameter name to its existing data-group config so that
        re-attaching the freshly copied model preserves per-tensor settings.
        """
        # A plain dict suffices here; a factory-less defaultdict behaves identically.
        config: Dict = {}
        if self.data_sparsifier_state_dict is None:
            return config
        for name, _ in pl_module.model.named_parameters():
            valid_name = _get_valid_name(name)
            config[valid_name] = self.data_sparsifier.data_groups[valid_name]

        return config

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Re-copy the model, step sparsifier and scheduler, and snapshot state."""
        self.sparsified = deepcopy(pl_module.model)
        config = self.__create_config_based_on_state(pl_module)

        # attach model to the data sparsifier
        _attach_model_to_data_sparsifier(
            self.sparsified, self.data_sparsifier, config=config
        )
        self.data_sparsifier.step()
        self.data_scheduler.step()

        self.data_sparsifier_state_dict = self.data_sparsifier.state_dict()

    def on_train_end(self, trainer, pl_module) -> None:
        """Make the accumulated sparsity permanent once training is over."""
        self.data_sparsifier.squash_mask()