# mypy: allow-untyped-defs
import abc
import copy
import sys
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.ao.pruning.sparsifier import base_sparsifier, utils
from torch.nn.utils import parametrize


if not sys.warnoptions:
    # to suppress repeated warnings when used inside a training loop.
    warnings.simplefilter("once")

__all__ = ["BaseDataSparsifier"]

EMBEDDING_TYPES = {
    nn.Embedding,
    nn.EmbeddingBag,
}

SUPPORTED_TYPES = {
    torch.Tensor,
    nn.Parameter,
    *EMBEDDING_TYPES,
}


class _Container(nn.Module):
    pass


class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
    r"""
    Base Data Sparsifier class for all Data sparsifiers.
    The abstract class accepts raw torch tensors / embeddings / embedding bags (refer to SUPPORTED_TYPES above)
    to prepare for sparsification.
    In this case, the mask (and parametrizations) are owned by the class and not by the user.
    Specifically, the container object inside the class maintains the masks and parametrizations of the input data.

    Args:
        data_list (list of tuples)
            list of (name, data) tuples to sparsify. Look up SUPPORTED_TYPES
            for the supported types of data. Internally, a container module handles the data sparsification.

        defaults (dict)
            default configurations attached to each data group's
            configuration. Only the keys that don't already exist in the data-specific
            `config` are filled in from the defaults.
    Example::
        >>> # xdoctest: +SKIP
        >>> data_list = [('tensor_1', torch.randn(3, 3)), ('tensor_2', torch.randn(4, 4))]
        >>> defaults = {'sparsity_level': 0.7}
        >>> sparsifier = DerivedDataSparsifier(data_list=data_list, **defaults)  # Some sparsifier that inherits BaseDataSparsifier
        >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5, 5), 'sparsity_level': 0.3}
        >>> sparsifier.add_data(**new_tensor_to_add)
        >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3
    """

    def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults):
        super().__init__(defaults=defaults)

        self._container = _Container()

        self.data_groups: Dict[str, Dict] = defaultdict(dict)  # name -> {**config}
        if data_list is not None:
            # add each data entry with the default config
            for name, data in data_list:
                self.add_data(name, data, **self.defaults)

    def prepare(self):
        raise NotImplementedError("this function is undefined for this class")

    def _extract_weight(self, data):
        # extract the weight parameter instead of the underlying data
        if type(data) in [torch.Tensor, nn.Parameter]:
            return data
        elif type(data) in EMBEDDING_TYPES:
            return data.weight

    def add_data(self, name: str, data, reuse_mask=True, **config):
        r"""Configures and parametrizes the internal container model with name and data.

        **Note**:
            1. If data with this name already exists, it is replaced.
            2. While replacing, the old mask is reused when `reuse_mask=True`.
            3. If `reuse_mask=True`, the replacing data needs to have the same shape as the old data.
            4. By default, the config of the replaced data is used for the replacing data, except for
               the keys explicitly given in the config dictionary.
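
        Example (a sketch; assumes ``sparsifier`` is an instance of some derived
        class, as in the class-level example above, and ``'tensor_4'`` is a
        hypothetical name)::
            >>> # xdoctest: +SKIP
            >>> sparsifier.add_data('tensor_4', torch.randn(3, 3), sparsity_level=0.5)
            >>> # adding again under the same name replaces the data; with reuse_mask=True
            >>> # the new tensor must keep the old (3, 3) shape
            >>> sparsifier.add_data('tensor_4', torch.randn(3, 3))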
        """
        assert (
            type(data) in SUPPORTED_TYPES
        ), "specified data type not supported at the moment"
        local_args = copy.deepcopy(self.defaults)
        local_args.update(config)
        weight = self._extract_weight(data)

        # Bookkeeping in the container class
        mask = local_args.get("mask", torch.ones_like(weight))
        param_class = local_args.get("parametrization", utils.FakeSparsity)

        if name in self.state:
            # If the named data already exists - replace
            warnings.warn(
                "Replacing existing data of the same name - did you mean a different name?"
            )

            # reuse the old config
            old_args = self.data_groups[name]
            local_args = copy.deepcopy(old_args)
            local_args.update(config)

            if reuse_mask:
                current_data = self.get_data(name=name)
                assert (
                    weight.shape == current_data.shape
                ), "to retain the old mask, the shape of the new data must be the same as the previous one"
                mask = self.get_mask(
                    name=name
                )  # reuse the mask instead of creating a new one

            self._delete_data(name=name)

        # nn.Parameter would deepcopy the weight internally, so register a buffer instead
        self._container.register_buffer(name=name, tensor=weight)
        parametrize.register_parametrization(self._container, name, param_class(mask))
        self.state[name]["mask"] = mask
        self.data_groups[name] = local_args
        return getattr(self._container, name)

    def get_data(self, name: str, return_original: bool = True):
        r"""Returns the weight tensor (or data)
        Args:
            - name: name of the data to be returned
            - return_original: if True, returns the weight tensor without the
              parametrization applied, else returns the sparsified (parametrized) version
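
        Example (a sketch; assumes ``'tensor_1'`` was added to ``sparsifier`` earlier)::
            >>> # xdoctest: +SKIP
            >>> original = sparsifier.get_data('tensor_1')  # unmasked weight
            >>> sparsified = sparsifier.get_data('tensor_1', return_original=False)  # mask applied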
        """
        if name not in self.data_groups:
            raise ValueError("data with specified name does not exist")

        if return_original:
            if not parametrize.is_parametrized(self._container, name):
                raise ValueError("mask squashed - original data value does not exist")
            data = getattr(self._container.parametrizations, name).original
            return data
        else:
            return getattr(self._container, name)

    def _convert_mask(self, states, sparse_coo=True):
        r"""Converts the masks to sparse coo or dense tensors depending on the `sparse_coo` argument."""
        states = copy.deepcopy(states)
        for state in states.values():
            if sparse_coo:
                state["mask"] = state["mask"].to_sparse_coo()
            else:
                state["mask"] = state["mask"].to_dense()

        return states

    def state_dict(self):
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - a name -> mask mapping.
        * data_groups - a dictionary mapping each data name to its
            sparsification configuration
        * _container - the state dictionary of the internal
            container model used for sparsification
        """
        state = self._convert_mask(self.state)
        return {
            "state": state,
            "data_groups": self.data_groups,
            "_container": self._container.state_dict(),
        }

    def _load_container_from_state(self, states, data_groups, container_state_dict):
        r"""Restores the state of the container based on the data present in states and data_groups.
        If the data was parametrized, it is added to the container and then parametrized again,
        else it is simply set as an attribute on the container.
        """
        for name, state in states.items():
            config = data_groups.get(name, None)
            if config is None:
                raise RuntimeError(f"Error loading {name}")

            # check if the data with such a name was parametrized; if so, parametrize it,
            # otherwise just set the attribute and continue
            parametrized_name = f"parametrizations.{name}.original"
            parametrized = False
            if name in container_state_dict:
                # the data was stored unparametrized (e.g. the mask was squashed)
                data = container_state_dict.get(name)

            elif parametrized_name in container_state_dict:
                # the weight was parametrized
                data = container_state_dict.get(parametrized_name)
                parametrized = True

            else:
                raise RuntimeError(f"Error loading {name}")

            self._container.register_buffer(name=name, tensor=data)

            if parametrized:
                # register the parametrization again if it was present
                mask = state.get("mask", torch.ones_like(data))
                param_class = config.get(
                    "parametrization", utils.FakeSparsity
                )  # change once public_api for utils is fixed!
                parametrize.register_parametrization(
                    self._container, name, param_class(mask)
                )

    def load_state_dict(self, state_dict, strict=True):
        r"""Restores the state of the sparsifier based on the state_dict

        Args:
        * state_dict - the dictionary to restore the state of the sparsifier from
        * strict - If True - the sparsifier is reset and restored exactly to the state in state_dict.
            If False - the current sparsifier is not reset before loading the state_dict, i.e. data added
            before loading the state_dict is not erased.
        """
        states = copy.deepcopy(state_dict["state"])
        data_groups = copy.deepcopy(state_dict["data_groups"])
        container_state_dict = copy.deepcopy(state_dict["_container"])

        states = self._convert_mask(
            states, sparse_coo=False
        )  # convert sparse coo masks to dense
        if strict:
            # on a strict load, reset the container first
            self._container = _Container()

        self._load_container_from_state(states, data_groups, container_state_dict)

        if not strict:
            states.update(self.state)
            data_groups.update(self.data_groups)

        self.__setstate__({"state": states, "data_groups": data_groups})

    def __setstate__(self, state):
        if "_container" in state:  # if the container state is present, load the model
            container_dict = state.pop("_container")
            self._container = _Container()
            state["state"] = self._convert_mask(
                state["state"], sparse_coo=False
            )  # convert sparse coo masks to dense
            self._load_container_from_state(
                state["state"], state["data_groups"], container_dict
            )

        self.__dict__.update(state)

    def __getstate__(self):
        state = self._convert_mask(self.state)
        return {
            "defaults": self.defaults,
            "state": state,
            "data_groups": self.data_groups,
            "_container": self._container.state_dict(),
        }

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for name, sparse_args in self.data_groups.items():
            format_string += "\n"
            format_string += "\tData Group\n"
            format_string += f"\t    name: {name}\n"
            for key in sorted(sparse_args.keys()):
                if key == "data":
                    continue
                format_string += f"\t    {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def get_mask(self, name: str):
        if name not in self.state:
            raise ValueError("data with specified name does not exist")
        return self.state[name]["mask"]

    def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs):
        r"""Squashes the sparse masks into the appropriate tensors. Also accepts a
        list of strings to squash masks for. If None, squashes the masks for all keys.
        kwargs:
            * names: list of strings to squash the mask for
            * leave_parametrized: if True - applies the mask before removing the parametrization,
                if False - restores the original data without applying the mask
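
        Example (a sketch; assumes ``'tensor_1'`` was added to ``sparsifier``
        earlier)::
            >>> # xdoctest: +SKIP
            >>> sparsifier.squash_mask()  # masks applied, parametrizations removed
            >>> # alternatively, squash selected names without applying the mask:
            >>> # sparsifier.squash_mask(names=['tensor_1'], leave_parametrized=False)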
        """
        if names is None:
            names = list(self.data_groups.keys())
        for name in names:
            parametrize.remove_parametrizations(
                self._container, name, leave_parametrized=leave_parametrized
            )

    def step(self):
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for name, config in self.data_groups.items():
                # get the non-sparsified data
                data = self.get_data(name)
                # the name is passed so update_mask can look up the corresponding mask
                self.update_mask(name, data, **config)

    @abc.abstractmethod
    def update_mask(self, name, data, **kwargs):
        pass

    def _delete_data(self, name):
        """Detaches some data from the sparsifier.

        Args:
            name (str)
                Name of the data to be removed from the sparsifier

        Note:
            Currently private. Used as a helper when replacing data of the same name.
        """
        self.squash_mask(
            names=[name], leave_parametrized=False
        )  # do not apply the mask while deleting
        delattr(self._container, name)
        self.state.pop(name)
        self.data_groups.pop(name)
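

# For illustration, a minimal sketch of a derived data sparsifier. This class is
# hypothetical and not part of the public API (note __all__ above); concrete
# implementations follow the same pattern: subclass BaseDataSparsifier and
# implement update_mask.
class _ExampleDataSparsifier(BaseDataSparsifier):
    """Zeroes out the smallest-magnitude fraction of entries of each data tensor."""

    def update_mask(self, name, data, sparsity_level, **kwargs):
        mask = self.get_mask(name)
        # number of entries to zero out for the requested sparsity level
        k = int(sparsity_level * data.numel())
        mask.data = torch.ones_like(mask)
        if k > 0:
            # indices of the k smallest-magnitude entries
            _, idx = torch.topk(data.abs().flatten(), k, largest=False)
            mask.data.view(-1)[idx] = 0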