# mypy: allow-untyped-defs
from itertools import chain
from operator import getitem
from typing import Callable, Dict, Optional, Set, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
from torch.fx import symbolic_trace
from torch.nn.utils import parametrize

from .match_utils import apply_match, MatchAllNode
from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param
from .prune_functions import (
    prune_conv2d,
    prune_conv2d_activation_conv2d,
    prune_conv2d_activation_pool_conv2d,
    prune_conv2d_conv2d,
    prune_conv2d_pool_activation_conv2d,
    prune_conv2d_pool_flatten_linear,
    prune_linear,
    prune_linear_activation_linear,
    prune_linear_linear,
    prune_lstm_output_layernorm_linear,
    prune_lstm_output_linear,
)


def _get_supported_structured_pruning_modules():
    SUPPORTED_STRUCTURED_PRUNING_MODULES = {  # added to config if None given
        nn.Linear,
        nn.Conv2d,
        nn.LSTM,
    }
    return SUPPORTED_STRUCTURED_PRUNING_MODULES


def _get_supported_activation_functions():
    SUPPORTED_ACTIVATION_FUNCTIONS = {
        F.relu,
        F.rrelu,
        F.hardtanh,
        F.relu6,
        F.sigmoid,
        F.hardsigmoid,
        F.tanh,
        F.silu,
        F.mish,
        F.hardswish,
        F.elu,
        F.celu,
        F.selu,
        F.hardshrink,
        F.leaky_relu,
        F.logsigmoid,
        F.softplus,
        F.prelu,
        F.softsign,
        F.tanhshrink,
        F.gelu,
    }
    return SUPPORTED_ACTIVATION_FUNCTIONS


def _get_supported_activation_modules():
    SUPPORTED_ACTIVATION_MODULES = {
        nn.ReLU,
        nn.RReLU,
        nn.Hardtanh,
        nn.ReLU6,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.Tanh,
        nn.SiLU,
        nn.Mish,
        nn.Hardswish,
        nn.ELU,
        nn.CELU,
        nn.SELU,
        nn.Hardshrink,
        nn.LeakyReLU,
        nn.LogSigmoid,
        nn.Softplus,
        nn.PReLU,
        nn.Softsign,
        nn.Tanhshrink,
        nn.GELU,
    }
    return SUPPORTED_ACTIVATION_MODULES


def _get_default_structured_pruning_patterns() -> (
    Dict[
        Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
        Callable[..., None],
    ]
):
    """
    Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above.
    """
    patterns: Dict[
        Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
        Callable[..., None],
    ] = {
        # linear -> linear
        (nn.Linear, "output"): prune_linear,
        (nn.Linear, nn.Linear): prune_linear_linear,
        # conv2d -> conv2d
        (nn.Conv2d, "output"): prune_conv2d,
        (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d,
        # TODO LSTM structured pruning does not support returned state currently.
        # Should find a way to explicitly match getitem(0) instead of getitem.
        # This will also require changing the pruning function.
        # lstm -> getitem(0) -> linear
        (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear,
        # lstm -> getitem(0) -> layernorm -> linear
        (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear,
    }

    for activation in chain(
        _get_supported_activation_functions(), _get_supported_activation_modules()
    ):
        patterns.update(
            {
                # linear -> activation -> linear
                (nn.Linear, activation, nn.Linear): prune_linear_activation_linear,
                # conv2d -> activation -> conv2d
                (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d,
                # conv2d -> activation -> pool -> conv2d
                (
                    nn.Conv2d,
                    activation,
                    nn.AvgPool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    F.avg_pool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    nn.MaxPool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    F.max_pool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                # conv2d -> pool -> activation -> conv2d
                (
                    nn.Conv2d,
                    nn.AvgPool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    F.avg_pool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    nn.MaxPool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    F.max_pool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                # conv2d -> adaptive pool -> flatten -> linear
                (
                    nn.Conv2d,
                    nn.AdaptiveAvgPool2d,
                    nn.Flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveAvgPool2d,
                    torch.flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveMaxPool2d,
                    nn.Flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveMaxPool2d,
                    torch.flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
            }
        )
    return patterns


class BaseStructuredSparsifier(BaseSparsifier):
    r"""Base class for structured pruning.

    Abstract methods that need to be implemented:
        - update_mask: Function to compute a new mask for all keys in the
            `groups` attribute.

    Args:
        - defaults [dict]: default configurations will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.
    """

    def __init__(self, defaults, patterns=None):
        super().__init__(defaults)
        if patterns is None:
            patterns = _get_default_structured_pruning_patterns()
        self.patterns = patterns

    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: Optional[Set[Type]] = None,
    ) -> None:
        if SUPPORTED_MODULES is None:
            SUPPORTED_MODULES = _get_supported_structured_pruning_modules()
        super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES)

    def _prepare(self, *args, **kwargs) -> None:
        r"""This function will attach the FakeStructuredSparsity parametrizations
        and BiasHooks at the appropriate points in the model.
241 """ 242 for config in self.groups: 243 module = config["module"] 244 tensor_name = config["tensor_name"] 245 parametrization = config.get("parametrization", FakeStructuredSparsity) 246 tensor = getattr(module, tensor_name) 247 248 mask = config.get( 249 "mask", 250 torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), 251 ) 252 self.state[config["tensor_fqn"]]["mask"] = mask 253 parametrize.register_parametrization( 254 module, tensor_name, parametrization(mask) 255 ) 256 257 # if linear / conv, we add in bias hooks 258 if isinstance(module, (nn.Linear, nn.Conv2d)): 259 prune_bias = config.get("prune_bias", True) 260 if module.bias is not None: 261 module.register_parameter( 262 "_bias", nn.Parameter(module.bias.detach()) 263 ) 264 module.bias = None 265 module.prune_bias = prune_bias 266 267 module.register_forward_hook( 268 BiasHook(module.parametrizations.weight[0], prune_bias) 269 ) 270 271 def prune(self) -> None: 272 r""" 273 This function will FX symbolically trace the model and then find instances of the patterns 274 defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). 275 276 For each pattern, it will apply to corresponding conversion function, which will modify the output 277 and input size expected by the modules within the pattern 278 """ 279 280 self.traced = symbolic_trace(self.model) 281 modules = dict(self.traced.named_modules()) 282 283 # Right now we check for matches simply by iterating across all the patterns 284 # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup 285 for node in self.traced.graph.nodes: 286 for pattern, convert_fn in self.patterns.items(): 287 matched = apply_match(modules, pattern, node, []) 288 if matched is None: 289 continue 290 291 first_module = modules.get(node.target) 292 # check if first module exists and has appropriate parameterization, otherwise skip 293 if ( 294 first_module is not None 295 and parametrize.is_parametrized(first_module) 296 and module_contains_param(first_module, FakeStructuredSparsity) 297 ): 298 convert_block = [] 299 for node in matched: 300 if node.op == "call_module": 301 convert_block.append(modules.get(node.target)) 302 elif node.op == "call_function": 303 convert_block.append(node.target) 304 convert_fn(*convert_block) 305 306 for module in self.traced.modules(): 307 if module_contains_param(module, FakeStructuredSparsity): 308 raise Exception( # noqa: TRY002 309 f"Error: {module} still contains FakeStructuredSparsity parametrizations!" 310 ) 311 312 self.traced.graph.lint() 313 self.traced.recompile() 314 return self.traced # type: ignore[return-value] 315