# mypy: allow-untyped-defs
from itertools import chain
from operator import getitem
from typing import Callable, Dict, Optional, Set, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
from torch.fx import symbolic_trace
from torch.nn.utils import parametrize

from .match_utils import apply_match, MatchAllNode
from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param
from .prune_functions import (
    prune_conv2d,
    prune_conv2d_activation_conv2d,
    prune_conv2d_activation_pool_conv2d,
    prune_conv2d_conv2d,
    prune_conv2d_pool_activation_conv2d,
    prune_conv2d_pool_flatten_linear,
    prune_linear,
    prune_linear_activation_linear,
    prune_linear_linear,
    prune_lstm_output_layernorm_linear,
    prune_lstm_output_linear,
)


def _get_supported_structured_pruning_modules():
    SUPPORTED_STRUCTURED_PRUNING_MODULES = {  # added to config if None given
        nn.Linear,
        nn.Conv2d,
        nn.LSTM,
    }
    return SUPPORTED_STRUCTURED_PRUNING_MODULES


def _get_supported_activation_functions():
    SUPPORTED_ACTIVATION_FUNCTIONS = {
        F.relu,
        F.rrelu,
        F.hardtanh,
        F.relu6,
        F.sigmoid,
        F.hardsigmoid,
        F.tanh,
        F.silu,
        F.mish,
        F.hardswish,
        F.elu,
        F.celu,
        F.selu,
        F.hardshrink,
        F.leaky_relu,
        F.logsigmoid,
        F.softplus,
        F.prelu,
        F.softsign,
        F.tanhshrink,
        F.gelu,
    }
    return SUPPORTED_ACTIVATION_FUNCTIONS


def _get_supported_activation_modules():
    SUPPORTED_ACTIVATION_MODULES = {
        nn.ReLU,
        nn.RReLU,
        nn.Hardtanh,
        nn.ReLU6,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.Tanh,
        nn.SiLU,
        nn.Mish,
        nn.Hardswish,
        nn.ELU,
        nn.CELU,
        nn.SELU,
        nn.Hardshrink,
        nn.LeakyReLU,
        nn.LogSigmoid,
        nn.Softplus,
        nn.PReLU,
        nn.Softsign,
        nn.Tanhshrink,
        nn.GELU,
    }
    return SUPPORTED_ACTIVATION_MODULES


def _get_default_structured_pruning_patterns() -> (
    Dict[
        Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
        Callable[..., None],
    ]
):
    """
    Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above.
    """
    patterns: Dict[
        Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
        Callable[..., None],
    ] = {
        # linear -> linear
        (nn.Linear, "output"): prune_linear,
        (nn.Linear, nn.Linear): prune_linear_linear,
        # conv2d -> conv2d
        (nn.Conv2d, "output"): prune_conv2d,
        (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d,
        # TODO LSTM Structured pruning does not support returned state currently.
        # Should find a way to explicitly match getitem(0) instead of getitem.
        # This will also require changing the pruning function.
        # lstm -> getitem(0) -> linear
        (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear,
        # lstm -> getitem(0) -> layernorm -> linear
        (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear,
    }

    for activation in chain(
        _get_supported_activation_functions(), _get_supported_activation_modules()
    ):
        patterns.update(
            {
                # linear -> activation -> linear
                (nn.Linear, activation, nn.Linear): prune_linear_activation_linear,
                # conv2d -> activation -> conv2d
                (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d,
                # conv2d -> activation -> pool -> conv2d
                (
                    nn.Conv2d,
                    activation,
                    nn.AvgPool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    F.avg_pool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    nn.MaxPool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                (
                    nn.Conv2d,
                    activation,
                    F.max_pool2d,
                    nn.Conv2d,
                ): prune_conv2d_activation_pool_conv2d,
                # conv2d -> pool -> activation -> conv2d
                (
                    nn.Conv2d,
                    nn.AvgPool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    F.avg_pool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    nn.MaxPool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                (
                    nn.Conv2d,
                    F.max_pool2d,
                    activation,
                    nn.Conv2d,
                ): prune_conv2d_pool_activation_conv2d,
                # conv2d -> adaptive pool -> flatten -> linear
                (
                    nn.Conv2d,
                    nn.AdaptiveAvgPool2d,
                    nn.Flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveAvgPool2d,
                    torch.flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveMaxPool2d,
                    nn.Flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
                (
                    nn.Conv2d,
                    nn.AdaptiveMaxPool2d,
                    torch.flatten,
                    nn.Linear,
                ): prune_conv2d_pool_flatten_linear,
            }
        )
    return patterns

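# A hedged sketch of how a caller could extend the default pattern table: keys are
# tuples read left-to-right along the FX graph (module types, functions,
# MatchAllNode, or the special string "output"), and values are conversion
# callables that receive the matched modules/functions in order. The names
# `my_patterns`, `prune_conv2d_dropout_conv2d`, and `MyStructuredSparsifier` below
# are illustrative assumptions, not part of this module:
#
#     my_patterns = _get_default_structured_pruning_patterns()
#     my_patterns[(nn.Conv2d, nn.Dropout2d, nn.Conv2d)] = prune_conv2d_dropout_conv2d
#     pruner = MyStructuredSparsifier(defaults, patterns=my_patterns)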

class BaseStructuredSparsifier(BaseSparsifier):
    r"""Base class for structured pruning.

    Abstract methods that need to be implemented:
        - update_mask: Function to compute a new mask for all keys in the
            `groups` attribute.

    Args:
        - defaults [dict]: default configuration values that are attached to every
            entry of the pruning config. Only keys that are not already present in
            an entry are filled in from the defaults.

    A minimal subclass sketch is included at the bottom of this file.
    """
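
    # Note: `prepare(model, config)` follows the BaseSparsifier convention. With
    # config=None, make_config_from_model adds one entry per supported module
    # (Linear / Conv2d / LSTM) found in the model; an explicit config (a list of
    # dicts keyed by "tensor_fqn") restricts pruning to specific tensors instead.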

    def __init__(self, defaults, patterns=None):
        super().__init__(defaults)
        if patterns is None:
            patterns = _get_default_structured_pruning_patterns()
        self.patterns = patterns

    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: Optional[Set[Type]] = None,
    ) -> None:
        if SUPPORTED_MODULES is None:
            SUPPORTED_MODULES = _get_supported_structured_pruning_modules()
        super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES)

    def _prepare(self, *args, **kwargs) -> None:
        r"""This function attaches the FakeStructuredSparsity parametrizations
        and BiasHooks at the appropriate points in the model.
        """
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrization = config.get("parametrization", FakeStructuredSparsity)
            tensor = getattr(module, tensor_name)

            mask = config.get(
                "mask",
                torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device),
            )
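            # The mask holds one boolean per output channel (row) of the tensor:
            # True keeps the channel, False marks it for pruning. Subclasses flip
            # entries to False in their update_mask implementation.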
            self.state[config["tensor_fqn"]]["mask"] = mask
            parametrize.register_parametrization(
                module, tensor_name, parametrization(mask)
            )
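            # Once registered, accessing `module.<tensor_name>` returns the
            # parametrized tensor with masked rows zeroed out; the unmasked values
            # stay in `module.parametrizations.<tensor_name>.original`.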

            # if linear / conv, we add in bias hooks
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                prune_bias = config.get("prune_bias", True)
                if module.bias is not None:
                    module.register_parameter(
                        "_bias", nn.Parameter(module.bias.detach())
                    )
                    module.bias = None
                    module.prune_bias = prune_bias

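                # The forward hook re-applies the stashed `_bias` to the module
                # output and, when prune_bias is True, zeroes the bias entries of
                # masked channels so they can be removed along with the weights.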
                module.register_forward_hook(
                    BiasHook(module.parametrizations.weight[0], prune_bias)
                )

    def prune(self) -> torch.fx.GraphModule:
        r"""
        This function will FX symbolically trace the model and then find instances of the patterns
        defined in self.patterns (by default, the patterns returned by
        _get_default_structured_pruning_patterns).

        For each match, it applies the corresponding conversion function, which modifies the output
        and input sizes expected by the modules within the pattern.

        Returns the pruned, re-traced torch.fx.GraphModule.
        """

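        # torch.fx symbolic tracing requires the model's forward to be traceable
        # (in particular, no data-dependent Python control flow on tensor values).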
        self.traced = symbolic_trace(self.model)
        modules = dict(self.traced.named_modules())

        # Right now we check for matches simply by iterating across all the patterns;
        # if this is slow we can store patterns in a trie structure and modify this code for faster lookup
        for node in self.traced.graph.nodes:
            for pattern, convert_fn in self.patterns.items():
                matched = apply_match(modules, pattern, node, [])
                if matched is None:
                    continue

                first_module = modules.get(node.target)
                # check if the first module exists and has the appropriate parametrization, otherwise skip
                if (
                    first_module is not None
                    and parametrize.is_parametrized(first_module)
                    and module_contains_param(first_module, FakeStructuredSparsity)
                ):
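                    # Collect the concrete modules / functions that make up the
                    # matched pattern, in graph order, and hand them to the
                    # conversion function, which resizes the layers in place and
                    # removes the FakeStructuredSparsity parametrizations.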
                    convert_block = []
                    for match_node in matched:
                        if match_node.op == "call_module":
                            convert_block.append(modules.get(match_node.target))
                        elif match_node.op == "call_function":
                            convert_block.append(match_node.target)
                    convert_fn(*convert_block)

        for module in self.traced.modules():
            if module_contains_param(module, FakeStructuredSparsity):
                raise Exception(  # noqa: TRY002
                    f"Error: {module} still contains FakeStructuredSparsity parametrizations!"
                )

        self.traced.graph.lint()
        self.traced.recompile()
        return self.traced
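

# A minimal, hedged usage sketch, guarded so it has no effect on import. The
# subclass implements `update_mask` with a simple L1-norm criterion; the model,
# sparsity level, and the "tensor_fqn" config entry below are illustrative
# assumptions, not requirements of this module.
if __name__ == "__main__":

    class _L1NormPruner(BaseStructuredSparsifier):
        def update_mask(self, module, tensor_name, **kwargs):
            # Rank output channels by L1 norm and mask out the smallest ones.
            weights = getattr(module, tensor_name)
            mask = getattr(module.parametrizations, tensor_name)[0].mask
            saliency = weights.norm(p=1, dim=tuple(range(1, weights.dim())))
            num_to_prune = int(kwargs["sparsity_level"] * len(saliency))
            if num_to_prune > 0:
                mask.data[saliency.argsort()[:num_to_prune]] = False

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    pruner = _L1NormPruner({"sparsity_level": 0.5})
    # Restrict pruning to the first Linear layer by its fully qualified tensor name.
    pruner.prepare(model, config=[{"tensor_fqn": "0.weight"}])
    pruner.step()  # compute masks via update_mask
    pruned = pruner.prune()  # remove masked channels and re-trace the model
    print(pruned(torch.randn(2, 8)).shape)  # torch.Size([2, 4]); hidden width is now 4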