# mypy: allow-untyped-defs
"""
Collection of conversion functions for linear / conv2d structured pruning
Also contains utilities for bias propagation
"""
from typing import Callable, cast, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import ParametrizationList

from .parametrization import BiasHook, FakeStructuredSparsity


# BIAS PROPAGATION
def _remove_bias_handles(module: nn.Module) -> None:
    if hasattr(module, "_forward_hooks"):
        bias_hooks: List[int] = []
        for key, hook in module._forward_hooks.items():
            if isinstance(hook, BiasHook):
                bias_hooks.append(key)

        for key in bias_hooks:
            del module._forward_hooks[key]


def _get_adjusted_next_layer_bias(
    next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor
) -> nn.Parameter:
    r"""Returns new adjusted bias for the second supported module"""
    if parametrize.is_parametrized(next_layer):
        # need to access original weight
        parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations)
        weight_parameterizations = cast(
            ParametrizationList, parametrization_dict.weight
        )
        next_weight = weight_parameterizations.original
    else:
        next_weight = cast(Tensor, next_layer.weight)

    scaling_weight = next_weight[:, ~mask]
    if isinstance(next_layer, nn.Conv2d):  # checking for Conv2d
        # Propagating first layer pruned biases and calculating the new second layer bias
        # involves more steps since the Conv2d scaling weight has extra dimensions,
        # so adding bias involves broadcasting, logically:
        # for each channel k in range(oC):
        #     scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T)
        #     new_next_bias[k] = old_next_bias[k] + scaled_biases
        scaling_product = torch.matmul(
            pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2)
        )
        sum_range = list(range(len(scaling_product.shape)))[
            1:
        ]  # all but the first dimension
        scaled_biases = torch.sum(scaling_product, sum_range)
    elif isinstance(next_layer, nn.Linear):  # Linear
        scaled_biases = torch.matmul(
            pruned_biases, torch.transpose(scaling_weight, 0, 1)
        )  # recall b2_new = b1 @ w2.T + b2
    else:
        raise NotImplementedError(f"Type {type(next_layer)} not supported yet.")

    if (
        parametrize.is_parametrized(next_layer)
        and getattr(next_layer, "_bias", None) is not None
    ):  # next_layer is parametrized & has original bias ._bias
        adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias)
    elif (
        not parametrize.is_parametrized(next_layer) and next_layer.bias is not None
    ):  # next_layer not parametrized & has .bias
        adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias)
    else:  # next_layer has no bias
        adjusted_bias = nn.Parameter(scaled_biases)
    return adjusted_bias
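

# Illustrative sketch (added for exposition; not part of the original API).
# It shows why folding the pruned biases of one Linear into the bias of the next
# Linear, as `_get_adjusted_next_layer_bias` does, preserves the network output:
# once a row of W1 is zeroed out, that output channel is the constant b1[j], so its
# contribution to the next layer is act(b1[j]) * W2[:, j], which can be absorbed
# into b2. The helper below is hypothetical and never called by this module.
def _example_linear_bias_folding_identity() -> None:
    torch.manual_seed(0)
    linear1 = nn.Linear(4, 6)
    linear2 = nn.Linear(6, 3)
    mask = torch.tensor([True, False, True, True, False, True])

    x = torch.randn(2, 4)
    with torch.no_grad():
        # emulate FakeStructuredSparsity: zero out the pruned rows of W1
        w1 = linear1.weight.clone()
        w1[~mask] = 0
        out1 = torch.relu(nn.functional.linear(x, w1, linear1.bias))
        reference = linear2(out1)

        # pruned channels contribute the constant relu(b1[~mask]); fold it into b2
        pruned_biases = torch.relu(linear1.bias[~mask])
        adjusted_bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask)

        # pruned network: keep only the surviving channels of linear1 / columns of linear2
        kept = torch.relu(nn.functional.linear(x, w1[mask], linear1.bias[mask]))
        pruned_out = nn.functional.linear(kept, linear2.weight[:, mask], adjusted_bias)
        assert torch.allclose(reference, pruned_out, atol=1e-6)
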

def _prune_module_bias(module: nn.Module, mask: Tensor) -> None:
    r"""Applies the mask to the given module's bias."""
    # prune bias along with weights, discard pruned indices of bias
    original_bias = cast(Tensor, getattr(module, "_bias", module.bias))
    if original_bias is not None:
        module.bias = nn.Parameter(original_bias[mask])

    # remove _bias parameter
    if hasattr(module, "_bias"):
        delattr(module, "_bias")


def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
    r"""
    Prune the current module's bias and return the pruned bias values that need to be
    propagated to the next layer, or None if there is nothing to propagate.
    """
    # set current module bias
    if module.bias is not None:
        module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
    elif getattr(module, "_bias", None) is not None:
        module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])

    # get pruned biases to propagate to subsequent layer
    if getattr(module, "_bias", None) is not None:
        pruned_biases = cast(Tensor, module._bias)[~mask]
    else:
        pruned_biases = None

    if hasattr(module, "_bias"):
        delattr(module, "_bias")

    return pruned_biases


# LINEAR
def _prune_linear_helper(linear: nn.Linear) -> Tensor:
    # expects linear to be a parametrized linear module
    parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
        linear.weight = nn.Parameter(linear.weight[mask])  # type: ignore[possibly-undefined]
        linear.out_features = linear.weight.shape[0]
        _remove_bias_handles(linear)

    return mask


def prune_linear(linear: nn.Linear) -> None:
    mask = _prune_linear_helper(linear)
    if getattr(linear, "prune_bias", False):
        _prune_module_bias(linear, mask)


def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None:
    prune_linear_activation_linear(linear1, None, linear2)


def prune_linear_activation_linear(
    linear1: nn.Linear,
    activation: Optional[Callable[[Tensor], Tensor]],
    linear2: nn.Linear,
):
    mask = _prune_linear_helper(linear1)
    if getattr(linear1, "prune_bias", False):
        _prune_module_bias(linear1, mask)
    else:
        pruned_biases = _propagate_module_bias(linear1, mask)
        if pruned_biases is not None:
            if activation:
                pruned_biases = activation(pruned_biases)
            linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask)

    with torch.no_grad():
        if parametrize.is_parametrized(linear2):
            parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict.weight
            )

            weight_parameterizations.original = nn.Parameter(
                weight_parameterizations.original[:, mask]
            )
            linear2.in_features = weight_parameterizations.original.shape[1]
        else:
            linear2.weight = nn.Parameter(linear2.weight[:, mask])
            linear2.in_features = linear2.weight.shape[1]
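

# Illustrative sketch (added for exposition; not part of the original API).
# It shows the shape effect of `prune_linear_activation_linear` on a pair of Linear
# modules whose first weight carries a FakeStructuredSparsity parametrization
# (assumed to take a boolean mask over output channels, as in the accompanying
# `parametrization` module). Bias handling in the real flow is driven by the
# sparsifier (via `prune_bias` / `_bias`), which this sketch does not reproduce.
def _example_prune_linear_activation_linear() -> None:
    linear1 = nn.Linear(8, 6)
    linear2 = nn.Linear(6, 4)
    mask = torch.tensor([True, True, False, True, False, True])  # keep 4 of 6 channels
    parametrize.register_parametrization(
        linear1, "weight", FakeStructuredSparsity(mask)
    )

    with torch.no_grad():
        prune_linear_activation_linear(linear1, torch.relu, linear2)

    # linear1 loses its pruned output channels, linear2 the matching input columns
    assert linear1.weight.shape == (4, 8) and linear1.out_features == 4
    assert linear2.weight.shape == (4, 4) and linear2.in_features == 4
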

# CONV2D
def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
    parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
        conv2d.weight = nn.Parameter(conv2d.weight[mask])  # type: ignore[possibly-undefined]
        conv2d.out_channels = conv2d.weight.shape[0]

        _remove_bias_handles(conv2d)
        return mask


def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True)

    if getattr(conv2d_1, "_bias", None) is not None:
        if (
            conv2d_1.bias is not None
        ):  # conv2d_1 has original bias and bias propagated from previous layer
            new_bias = torch.zeros(conv2d_1.bias.shape)
            new_bias[mask] = conv2d_1.bias[mask]  # type: ignore[possibly-undefined]
            # adjusted bias to keep in conv2d_1
            new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
            # pruned biases that are kept instead of propagated
            conv2d_1.bias = nn.Parameter(new_bias)
        else:  # conv2d_1 has only original bias
            conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias))
    else:
        # no original bias, only propagated bias
        if (
            conv2d_1.bias is not None
        ):  # conv2d_1 has bias propagated from previous layer
            conv2d_1.bias.data[~mask] = 0  # type: ignore[possibly-undefined]

    if hasattr(conv2d_1, "_bias"):
        delattr(conv2d_1, "_bias")


def prune_conv2d(conv2d: nn.Conv2d) -> None:
    mask = _prune_conv2d_helper(conv2d)
    if getattr(conv2d, "prune_bias", False):
        _prune_module_bias(conv2d, mask)


def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None:
    prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2)


def prune_conv2d_activation_conv2d(
    conv2d_1: nn.Conv2d,
    activation: Optional[Callable[[Tensor], Tensor]],
    conv2d_2: nn.Conv2d,
):
    r"""
    Fusion Pattern for conv2d -> some activation module / function -> conv2d layers
    """
    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    prune_bias = getattr(conv2d_1, "prune_bias", False)
    if (
        hasattr(conv2d_2, "padding")
        and cast(Tuple[int], conv2d_2.padding) > (0, 0)
        and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None)
    ):
        prune_conv2d_padded(conv2d_1)
    else:
        mask = _prune_conv2d_helper(conv2d_1)
        if prune_bias:
            _prune_module_bias(conv2d_1, mask)
        else:
            pruned_biases = _propagate_module_bias(conv2d_1, mask)
            if pruned_biases is not None:
                if activation:
                    pruned_biases = activation(pruned_biases)
                conv2d_2.bias = _get_adjusted_next_layer_bias(
                    conv2d_2, pruned_biases, mask
                )

        if (
            not (
                hasattr(conv2d_2, "padding")
                and cast(Tuple[int], conv2d_2.padding) > (0, 0)
            )
            or conv2d_1.bias is None
        ):
            with torch.no_grad():
                if parametrize.is_parametrized(conv2d_2):
                    parametrization_dict = cast(
                        nn.ModuleDict, conv2d_2.parametrizations
                    )
                    weight_parameterizations = cast(
                        ParametrizationList, parametrization_dict.weight
                    )
                    weight_parameterizations.original = nn.Parameter(
                        weight_parameterizations.original[:, mask]
                    )
                    conv2d_2.in_channels = weight_parameterizations.original.shape[1]
                else:
                    conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask])
                    conv2d_2.in_channels = conv2d_2.weight.shape[1]
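

# Illustrative sketch (added for exposition; not part of the original API).
# It shows the shape effect of `prune_conv2d_activation_conv2d` when the second conv
# has no padding, so the pruned output channels of conv2d_1 can simply be dropped and
# the matching input channels of conv2d_2 removed. The FakeStructuredSparsity mask is
# assumed to be a boolean mask over conv2d_1's output channels, as in the accompanying
# `parametrization` module.
def _example_prune_conv2d_activation_conv2d() -> None:
    conv2d_1 = nn.Conv2d(3, 8, kernel_size=3)
    conv2d_2 = nn.Conv2d(8, 16, kernel_size=3)  # padding=0, so the padded path is not taken
    mask = torch.tensor([True, False, True, True, False, True, True, False])  # keep 5 of 8
    parametrize.register_parametrization(
        conv2d_1, "weight", FakeStructuredSparsity(mask)
    )

    with torch.no_grad():
        prune_conv2d_activation_conv2d(conv2d_1, torch.relu, conv2d_2)

    assert conv2d_1.weight.shape == (5, 3, 3, 3) and conv2d_1.out_channels == 5
    assert conv2d_2.weight.shape == (16, 5, 3, 3) and conv2d_2.in_channels == 5
    # the fused pattern still runs end to end on the pruned channel count
    out = conv2d_2(torch.relu(conv2d_1(torch.randn(1, 3, 8, 8))))
    assert out.shape == (1, 16, 4, 4)
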

def prune_conv2d_pool_activation_conv2d(
    c1: nn.Conv2d,
    pool: nn.Module,
    activation: Optional[Callable[[Tensor], Tensor]],
    c2: nn.Conv2d,
) -> None:
    prune_conv2d_activation_conv2d(c1, activation, c2)


def prune_conv2d_activation_pool_conv2d(
    c1: nn.Conv2d,
    activation: Optional[Callable[[Tensor], Tensor]],
    pool: nn.Module,
    c2: nn.Conv2d,
) -> None:
    prune_conv2d_activation_conv2d(c1, activation, c2)


def prune_conv2d_pool_flatten_linear(
    conv2d: nn.Conv2d,
    pool: nn.Module,
    flatten: Optional[Callable[[Tensor], Tensor]],
    linear: nn.Linear,
) -> None:
    mask = _prune_conv2d_helper(conv2d)

    # We map the pruned output channels of the Conv2d to the flattened input indices
    # of the Linear that follows the Flatten layer: we determine the flattening scale
    # (h * w) and expand the channel mask so that channel idx maps to the index range
    # [idx * h * w, (idx + 1) * h * w); pruned biases are likewise repeated h * w times.
    if parametrize.is_parametrized(linear):
        parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
        weight_parameterizations = cast(
            ParametrizationList, parametrization_dict.weight
        )
        linear_ic = weight_parameterizations.original.shape[1]
    else:
        linear_ic = linear.weight.shape[1]

    conv2d_oc = len(mask)
    assert (
        linear_ic % conv2d_oc == 0
    ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"

    flatten_scale = linear_ic // conv2d_oc
    flattened_mask = torch.tensor(
        [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device
    ).flatten()

    if getattr(conv2d, "prune_bias", False):
        _prune_module_bias(conv2d, mask)
    else:
        pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask))
        flattened_pruned_biases = torch.tensor(
            [[bias] * flatten_scale for bias in pruned_biases], device=mask.device
        ).flatten()
        linear.bias = _get_adjusted_next_layer_bias(
            linear, flattened_pruned_biases, flattened_mask
        )

    with torch.no_grad():
        if parametrize.is_parametrized(linear):
            parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict.weight
            )
            weight_parameterizations.original = nn.Parameter(
                weight_parameterizations.original[:, flattened_mask]
            )
            linear.in_features = weight_parameterizations.original.shape[1]
        else:
            linear.weight = nn.Parameter(linear.weight[:, flattened_mask])
            linear.in_features = linear.weight.shape[1]
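

# Illustrative sketch (added for exposition; not part of the original API).
# It shows how the channel mask of a Conv2d is expanded for the Linear that follows a
# pool + flatten: each conv channel owns a contiguous block of h * w flattened features,
# so the mask is repeated `flatten_scale` times per channel before the Linear columns
# are pruned. Setting `prune_bias = True` on the conv mimics the flag the sparsifier
# normally attaches during prepare(); that attribute is an assumption of this sketch.
def _example_prune_conv2d_pool_flatten_linear() -> None:
    conv2d = nn.Conv2d(3, 4, kernel_size=3)
    pool = nn.AdaptiveAvgPool2d(2)
    linear = nn.Linear(4 * 2 * 2, 5)  # 4 channels * (2 * 2) pooled features
    mask = torch.tensor([True, False, True, True])  # keep 3 of 4 channels
    parametrize.register_parametrization(
        conv2d, "weight", FakeStructuredSparsity(mask)
    )
    conv2d.prune_bias = True  # drop pruned bias entries instead of propagating them

    with torch.no_grad():
        prune_conv2d_pool_flatten_linear(conv2d, pool, torch.flatten, linear)

    assert conv2d.out_channels == 3
    # 3 remaining channels * flatten_scale (2 * 2) = 12 input features for the linear
    assert linear.weight.shape == (5, 12) and linear.in_features == 12
    out = linear(torch.flatten(pool(conv2d(torch.randn(1, 3, 8, 8))), 1))
    assert out.shape == (1, 5)
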

def prune_lstm_output_linear(
    lstm: nn.LSTM, getitem: Callable, linear: nn.Linear
) -> None:
    prune_lstm_output_layernorm_linear(lstm, getitem, None, linear)


def prune_lstm_output_layernorm_linear(
    lstm: nn.LSTM,
    getitem: Callable,
    layernorm: Optional[nn.LayerNorm],
    linear: nn.Linear,
) -> None:
    for i in range(lstm.num_layers):
        if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"):
            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict[f"weight_ih_l{i}"]
            )
            mask = weight_parameterizations[0].mask

            with torch.no_grad():
                parametrize.remove_parametrizations(
                    lstm, f"weight_ih_l{i}", leave_parametrized=True
                )
                setattr(
                    lstm,
                    f"weight_ih_l{i}",
                    nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]),
                )
                setattr(
                    lstm,
                    f"bias_ih_l{i}",
                    nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]),
                )

        if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"):
            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict[f"weight_hh_l{i}"]
            )
            mask = weight_parameterizations[0].mask

            with torch.no_grad():
                parametrize.remove_parametrizations(
                    lstm, f"weight_hh_l{i}", leave_parametrized=True
                )
                # splitting out hidden-hidden masks
                W_hi, W_hf, W_hg, W_ho = torch.split(
                    getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size
                )
                M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size)

                # resize each individual weight separately
                W_hi = W_hi[M_hi][:, M_hi]
                W_hf = W_hf[M_hf][:, M_hf]
                W_hg = W_hg[M_hg][:, M_hg]
                W_ho = W_ho[M_ho][:, M_ho]

                # concat, use this as new weight
                new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho))
                setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight))
                setattr(
                    lstm,
                    f"bias_hh_l{i}",
                    nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]),
                )

            # If this is the final layer, then we need to prune linear layer columns
            if i + 1 == lstm.num_layers:
                lstm.hidden_size = int(M_hi.sum())
                with torch.no_grad():
                    if parametrize.is_parametrized(linear):
                        parametrization_dict = cast(
                            nn.ModuleDict, linear.parametrizations
                        )
                        weight_parameterizations = cast(
                            ParametrizationList, parametrization_dict.weight
                        )

                        weight_parameterizations.original = nn.Parameter(
                            weight_parameterizations.original[:, M_ho]
                        )
                        linear.in_features = weight_parameterizations.original.shape[1]
                    else:
                        linear.weight = nn.Parameter(linear.weight[:, M_ho])
                        linear.in_features = linear.weight.shape[1]

                # if layernorm module, prune weight and bias
                if layernorm is not None:
                    layernorm.normalized_shape = (linear.in_features,)
                    layernorm.weight = nn.Parameter(layernorm.weight[M_ho])
                    layernorm.bias = nn.Parameter(layernorm.bias[M_ho])

            # otherwise need to prune the columns of the input of the next LSTM layer
            else:
                with torch.no_grad():
                    if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"):
                        parametrization_dict = cast(
                            nn.ModuleDict, lstm.parametrizations
                        )
                        weight_parameterizations = cast(
                            ParametrizationList,
                            getattr(parametrization_dict, f"weight_ih_l{i + 1}"),
                        )

                        weight_parameterizations.original = nn.Parameter(
                            weight_parameterizations.original[:, M_ho]
                        )
                    else:
                        next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}")
                        setattr(
                            lstm,
                            f"weight_ih_l{i + 1}",
                            nn.Parameter(next_layer_weight[:, M_ho]),
                        )
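

# Illustrative sketch (added for exposition; not part of the original API).
# It shows `prune_lstm_output_linear` on a single-layer LSTM. weight_ih_l0 / weight_hh_l0
# stack the four gate blocks (i, f, g, o) along dim 0, so the row mask is a per-hidden-unit
# keep mask repeated four times; the LSTM pruner is expected to produce masks of this form.
# Both weights carry their own FakeStructuredSparsity parametrization here, which is an
# assumption of this sketch.
def _example_prune_lstm_output_linear() -> None:
    import operator

    lstm = nn.LSTM(input_size=4, hidden_size=6, num_layers=1, batch_first=True)
    linear = nn.Linear(6, 3)

    keep = torch.tensor([True, True, False, True, False, True])  # keep 4 of 6 hidden units
    gate_mask = keep.repeat(4)  # same units kept in each of the i, f, g, o gate blocks
    parametrize.register_parametrization(
        lstm, "weight_ih_l0", FakeStructuredSparsity(gate_mask)
    )
    parametrize.register_parametrization(
        lstm, "weight_hh_l0", FakeStructuredSparsity(gate_mask)
    )

    with torch.no_grad():
        prune_lstm_output_linear(lstm, operator.getitem, linear)

    assert lstm.hidden_size == 4
    assert lstm.weight_ih_l0.shape == (16, 4) and lstm.weight_hh_l0.shape == (16, 4)
    assert linear.in_features == 4
    out, _ = lstm(torch.randn(2, 5, 4))
    assert linear(out).shape == (2, 5, 3)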