xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/_experimental/pruner/prune_functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
Collection of conversion functions for linear / conv2d structured pruning
Also contains utilities for bias propagation
"""
from typing import Callable, cast, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import ParametrizationList

from .parametrization import BiasHook, FakeStructuredSparsity


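# Typical usage (illustrative sketch only; in practice the structured-pruning sparsifier in
# this package registers the FakeStructuredSparsity parametrizations and BiasHook forward
# hooks before calling the conversion functions below):
#
#   linear1, linear2 = nn.Linear(8, 4), nn.Linear(4, 2)
#   mask = torch.tensor([True, False, True, True])
#   parametrize.register_parametrization(linear1, "weight", FakeStructuredSparsity(mask))
#   prune_linear_activation_linear(linear1, torch.relu, linear2)
#   # linear1 now has out_features == 3 and linear2 has in_features == 3
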
# BIAS PROPAGATION
def _remove_bias_handles(module: nn.Module) -> None:
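    r"""Removes any BiasHook forward hooks attached to the given module during pruning"""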
    if hasattr(module, "_forward_hooks"):
        bias_hooks: List[int] = []
        for key, hook in module._forward_hooks.items():
            if isinstance(hook, BiasHook):
                bias_hooks.append(key)

        for key in bias_hooks:
            del module._forward_hooks[key]


def _get_adjusted_next_layer_bias(
    next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor
) -> nn.Parameter:
    r"""Returns the new adjusted bias for the next (second) module, folding in the pruned biases of the previous layer"""
    if parametrize.is_parametrized(next_layer):
        # need to access original weight
        parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations)
        weight_parameterizations = cast(
            ParametrizationList, parametrization_dict.weight
        )
        next_weight = weight_parameterizations.original
    else:
        next_weight = cast(Tensor, next_layer.weight)

    scaling_weight = next_weight[:, ~mask]
    if isinstance(next_layer, nn.Conv2d):  # checking for Conv2d
        # Propagating first layer pruned biases and calculating the new second layer bias
        # involves more steps since the Conv2d scaling weight has extra dimensions,
        # so adding bias involves broadcasting, logically:
        # for each channel k in range(oC):
        #     scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T)
        #     new_next_bias[k] = old_next_bias[k] + scaled_biases
        scaling_product = torch.matmul(
            pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2)
        )
        sum_range = list(range(len(scaling_product.shape)))[
            1:
        ]  # all but the first dimension
        scaled_biases = torch.sum(scaling_product, sum_range)
    elif isinstance(next_layer, nn.Linear):  # Linear
        scaled_biases = torch.matmul(
            pruned_biases, torch.transpose(scaling_weight, 0, 1)
        )  # recall b2_new = b1 @ w2.T + b2
    else:
        raise NotImplementedError(f"Type {type(next_layer)} not supported yet.")

    if (
        parametrize.is_parametrized(next_layer)
        and getattr(next_layer, "_bias", None) is not None
    ):  # next_layer is parametrized & has original bias ._bias
        adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias)
    elif (
        not parametrize.is_parametrized(next_layer) and next_layer.bias is not None
    ):  # next_layer not parametrized & has .bias
        adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias)
    else:  # next_layer has no bias
        adjusted_bias = nn.Parameter(scaled_biases)
    return adjusted_bias


def _prune_module_bias(module: nn.Module, mask: Tensor) -> None:
    r"""Applies the mask to the given module's bias, discarding the pruned entries"""
    # prune bias along with weights, discard pruned indices of bias
    original_bias = cast(Tensor, getattr(module, "_bias", module.bias))
    if original_bias is not None:
        module.bias = nn.Parameter(original_bias[mask])

    #  remove _bias parameter
    if hasattr(module, "_bias"):
        delattr(module, "_bias")


def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
    r"""
    Prunes the module's bias along the mask and returns the pruned bias entries
    (taken from module._bias) that need to be propagated to the next layer.
    Returns None if there is nothing to propagate.
    """
    # set current module bias
    if module.bias is not None:
        module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
    elif getattr(module, "_bias", None) is not None:
        module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])

    # get pruned biases to propagate to subsequent layer
    if getattr(module, "_bias", None) is not None:
        pruned_biases = cast(Tensor, module._bias)[~mask]
    else:
        pruned_biases = None

    if hasattr(module, "_bias"):
        delattr(module, "_bias")

    return pruned_biases


# LINEAR
def _prune_linear_helper(linear: nn.Linear) -> Tensor:
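    r"""
    Collapses the weight parametrization on the linear module, prunes the masked-out
    rows of its weight, updates out_features and returns the mask that was applied
    """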
    # expects linear to be a parameterized linear module
    parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
        linear.weight = nn.Parameter(linear.weight[mask])  # type: ignore[possibly-undefined]
    linear.out_features = linear.weight.shape[0]
    _remove_bias_handles(linear)

    return mask


def prune_linear(linear: nn.Linear) -> None:
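    r"""Prunes a parametrized linear module in place; also prunes its bias if prune_bias is set"""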
    mask = _prune_linear_helper(linear)
    if getattr(linear, "prune_bias", False):
        _prune_module_bias(linear, mask)


def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None:
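    r"""Prunes linear1 and resizes linear2 to match; same as the linear -> activation -> linear pattern with no activation"""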
    prune_linear_activation_linear(linear1, None, linear2)


def prune_linear_activation_linear(
    linear1: nn.Linear,
    activation: Optional[Callable[[Tensor], Tensor]],
    linear2: nn.Linear,
):
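    r"""
    Fusion pattern for linear1 -> activation -> linear2 layers

    Prunes the masked output rows of linear1 and the matching input columns of linear2.
    If linear1.prune_bias is set, the bias of linear1 is pruned as well; otherwise the
    pruned bias values are passed through the activation (if any) and folded into the
    bias of linear2 via _get_adjusted_next_layer_bias.
    """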
    mask = _prune_linear_helper(linear1)
    if getattr(linear1, "prune_bias", False):
        _prune_module_bias(linear1, mask)
    else:
        pruned_biases = _propagate_module_bias(linear1, mask)
        if pruned_biases is not None:
            if activation:
                pruned_biases = activation(pruned_biases)
            linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask)

    with torch.no_grad():
        if parametrize.is_parametrized(linear2):
            parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict.weight
            )

            weight_parameterizations.original = nn.Parameter(
                weight_parameterizations.original[:, mask]
            )
            linear2.in_features = weight_parameterizations.original.shape[1]
        else:
            linear2.weight = nn.Parameter(linear2.weight[:, mask])
            linear2.in_features = linear2.weight.shape[1]


# CONV2D
def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
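    r"""
    Collapses the weight parametrization on the conv2d module, prunes the masked-out
    output channels of its weight, updates out_channels and returns the mask that was
    applied
    """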
    parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
        conv2d.weight = nn.Parameter(conv2d.weight[mask])  # type: ignore[possibly-undefined]
    conv2d.out_channels = conv2d.weight.shape[0]

    _remove_bias_handles(conv2d)
    return mask


def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
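    r"""
    Handles the case where the pruned conv2d is followed by a conv2d with nonzero
    padding, so the pruned output channels cannot simply be removed and their biases
    folded into the next layer. The weight parametrization is collapsed (masked-out
    channels become zeroed weights) and the bias entries of pruned channels are kept
    in conv2d_1 instead of being propagated.
    """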
    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    with torch.no_grad():
        parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True)

    if getattr(conv2d_1, "_bias", None) is not None:
        if (
            conv2d_1.bias is not None
        ):  # conv2d_1 has original bias and bias propagated from previous layer
            new_bias = torch.zeros(conv2d_1.bias.shape)
            new_bias[mask] = conv2d_1.bias[mask]  # type: ignore[possibly-undefined]
            # adjusted bias to keep in conv2d_1
            new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
            # pruned biases that are kept instead of propagated
            conv2d_1.bias = nn.Parameter(new_bias)
        else:  # conv2d_1 has only original bias
            conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias))
    else:
        # no original bias, only propagated bias
        if (
            conv2d_1.bias is not None
        ):  # conv2d_1 has bias propagated from previous layer
            conv2d_1.bias.data[~mask] = 0  # type: ignore[possibly-undefined]

    if hasattr(conv2d_1, "_bias"):
        delattr(conv2d_1, "_bias")


def prune_conv2d(conv2d: nn.Conv2d) -> None:
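    r"""Prunes a parametrized conv2d module in place; also prunes its bias if prune_bias is set"""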
    mask = _prune_conv2d_helper(conv2d)
    if getattr(conv2d, "prune_bias", False):
        _prune_module_bias(conv2d, mask)


def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None:
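    r"""Prunes conv2d_1 and resizes conv2d_2 to match; same as the conv2d -> activation -> conv2d pattern with no activation"""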
    prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2)


def prune_conv2d_activation_conv2d(
    conv2d_1: nn.Conv2d,
    activation: Optional[Callable[[Tensor], Tensor]],
    conv2d_2: nn.Conv2d,
):
    r"""
    Fusion pattern for conv2d -> some activation module / function -> conv2d layers

    Prunes the masked output channels of conv2d_1 and the matching input channels of
    conv2d_2, pruning or propagating biases as configured. Falls back to
    prune_conv2d_padded when conv2d_2 uses nonzero padding and conv2d_1 carries a bias
    that would need to be propagated.
    """
    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
    for p in weight_parameterizations:
        if isinstance(p, FakeStructuredSparsity):
            mask = cast(Tensor, p.mask)

    prune_bias = getattr(conv2d_1, "prune_bias", False)
    if (
        hasattr(conv2d_2, "padding")
        and cast(Tuple[int], conv2d_2.padding) > (0, 0)
        and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None)
    ):
        prune_conv2d_padded(conv2d_1)
    else:
        mask = _prune_conv2d_helper(conv2d_1)
        if prune_bias:
            _prune_module_bias(conv2d_1, mask)
        else:
            pruned_biases = _propagate_module_bias(conv2d_1, mask)
            if pruned_biases is not None:
                if activation:
                    pruned_biases = activation(pruned_biases)
                conv2d_2.bias = _get_adjusted_next_layer_bias(
                    conv2d_2, pruned_biases, mask
                )

        if (
            not (
                hasattr(conv2d_2, "padding")
                and cast(Tuple[int], conv2d_2.padding) > (0, 0)
            )
            or conv2d_1.bias is None
        ):
            with torch.no_grad():
                if parametrize.is_parametrized(conv2d_2):
                    parametrization_dict = cast(
                        nn.ModuleDict, conv2d_2.parametrizations
                    )
                    weight_parameterizations = cast(
                        ParametrizationList, parametrization_dict.weight
                    )
                    weight_parameterizations.original = nn.Parameter(
                        weight_parameterizations.original[:, mask]
                    )
                    conv2d_2.in_channels = weight_parameterizations.original.shape[1]
                else:
                    conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask])
                    conv2d_2.in_channels = conv2d_2.weight.shape[1]


def prune_conv2d_pool_activation_conv2d(
    c1: nn.Conv2d,
    pool: nn.Module,
    activation: Optional[Callable[[Tensor], Tensor]],
    c2: nn.Conv2d,
) -> None:
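    r"""
    Fusion pattern for conv2d -> pool -> activation -> conv2d layers; pooling does not
    change the channel count, so this delegates to prune_conv2d_activation_conv2d
    """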
    prune_conv2d_activation_conv2d(c1, activation, c2)


def prune_conv2d_activation_pool_conv2d(
    c1: nn.Conv2d,
    activation: Optional[Callable[[Tensor], Tensor]],
    pool: nn.Module,
    c2: nn.Conv2d,
) -> None:
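    r"""
    Fusion pattern for conv2d -> activation -> pool -> conv2d layers; pooling does not
    change the channel count, so this delegates to prune_conv2d_activation_conv2d
    """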
    prune_conv2d_activation_conv2d(c1, activation, c2)


def prune_conv2d_pool_flatten_linear(
    conv2d: nn.Conv2d,
    pool: nn.Module,
    flatten: Optional[Callable[[Tensor], Tensor]],
    linear: nn.Linear,
) -> None:
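    r"""
    Fusion pattern for conv2d -> pool -> flatten -> linear layers

    Prunes the masked output channels of conv2d and removes the corresponding blocks of
    input features from linear: after flattening, each conv channel maps to flatten_scale
    consecutive linear input features, so the channel mask (and any propagated biases)
    are expanded by that factor before linear is resized.
    """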
    mask = _prune_conv2d_helper(conv2d)

    # Map pruned conv output channels to the flattened input features of the linear layer.
    # Each conv channel corresponds to flatten_scale = linear_ic // conv2d_oc consecutive
    # linear input features (the h * w spatial positions left after pooling/flattening),
    # so the channel mask and the pruned biases are each repeated flatten_scale times.
    if parametrize.is_parametrized(linear):
        parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
        weight_parameterizations = cast(
            ParametrizationList, parametrization_dict.weight
        )
        linear_ic = weight_parameterizations.original.shape[1]
    else:
        linear_ic = linear.weight.shape[1]

    conv2d_oc = len(mask)
    assert (
        linear_ic % conv2d_oc == 0
    ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"

    flatten_scale = linear_ic // conv2d_oc
    flattened_mask = torch.tensor(
        [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device
    ).flatten()

    if getattr(conv2d, "prune_bias", False):
        _prune_module_bias(conv2d, mask)
    else:
        pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask))
        flattened_pruned_biases = torch.tensor(
            [[bias] * flatten_scale for bias in pruned_biases], device=mask.device
        ).flatten()
        linear.bias = _get_adjusted_next_layer_bias(
            linear, flattened_pruned_biases, flattened_mask
        )

    with torch.no_grad():
        if parametrize.is_parametrized(linear):
            parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict.weight
            )
            weight_parameterizations.original = nn.Parameter(
                weight_parameterizations.original[:, flattened_mask]
            )
            linear.in_features = weight_parameterizations.original.shape[1]
        else:
            linear.weight = nn.Parameter(linear.weight[:, flattened_mask])
            linear.in_features = linear.weight.shape[1]


def prune_lstm_output_linear(
    lstm: nn.LSTM, getitem: Callable, linear: nn.Linear
) -> None:
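    r"""
    Fusion pattern for lstm -> getitem (output) -> linear layers; delegates to
    prune_lstm_output_layernorm_linear with no layernorm
    """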
    prune_lstm_output_layernorm_linear(lstm, getitem, None, linear)


def prune_lstm_output_layernorm_linear(
    lstm: nn.LSTM,
    getitem: Callable,
    layernorm: Optional[nn.LayerNorm],
    linear: nn.Linear,
) -> None:
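    r"""
    Fusion pattern for lstm -> getitem (output) -> layernorm (optional) -> linear layers

    For each LSTM layer, collapses the weight parametrizations on weight_ih_l{i} and
    weight_hh_l{i} and prunes the masked rows (and the matching hidden-hidden columns)
    of each of the four gate blocks. For the final layer, hidden_size is shrunk and the
    downstream layernorm / linear are resized to match; for intermediate layers, the
    input columns of the next LSTM layer are pruned instead.
    """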
    for i in range(lstm.num_layers):
        if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"):
            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict[f"weight_ih_l{i}"]
            )
            mask = weight_parameterizations[0].mask

            with torch.no_grad():
                parametrize.remove_parametrizations(
                    lstm, f"weight_ih_l{i}", leave_parametrized=True
                )
                setattr(
                    lstm,
                    f"weight_ih_l{i}",
                    nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]),
                )
                setattr(
                    lstm,
                    f"bias_ih_l{i}",
                    nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]),
                )

        if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"):
            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
            weight_parameterizations = cast(
                ParametrizationList, parametrization_dict[f"weight_hh_l{i}"]
            )
            mask = weight_parameterizations[0].mask

            with torch.no_grad():
                parametrize.remove_parametrizations(
                    lstm, f"weight_hh_l{i}", leave_parametrized=True
                )
                # splitting out hidden-hidden masks
                W_hi, W_hf, W_hg, W_ho = torch.split(
                    getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size
                )
                M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size)

                # resize each individual weight separately
                W_hi = W_hi[M_hi][:, M_hi]
                W_hf = W_hf[M_hf][:, M_hf]
                W_hg = W_hg[M_hg][:, M_hg]
                W_ho = W_ho[M_ho][:, M_ho]

                # concat, use this as new weight
                new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho))
                setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight))
                setattr(
                    lstm,
                    f"bias_hh_l{i}",
                    nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]),
                )

            # If this is the final layer, then we need to prune linear layer columns
            if i + 1 == lstm.num_layers:
                lstm.hidden_size = int(M_hi.sum())
                with torch.no_grad():
                    if parametrize.is_parametrized(linear):
                        parametrization_dict = cast(
                            nn.ModuleDict, linear.parametrizations
                        )
                        weight_parameterizations = cast(
                            ParametrizationList, parametrization_dict.weight
                        )

                        weight_parameterizations.original = nn.Parameter(
                            weight_parameterizations.original[:, M_ho]
                        )
                        linear.in_features = weight_parameterizations.original.shape[1]
                    else:
                        linear.weight = nn.Parameter(linear.weight[:, M_ho])
                        linear.in_features = linear.weight.shape[1]

                    # if layernorm module, prune weight and bias
                    if layernorm is not None:
                        layernorm.normalized_shape = (linear.in_features,)
                        layernorm.weight = nn.Parameter(layernorm.weight[M_ho])
                        layernorm.bias = nn.Parameter(layernorm.bias[M_ho])

            # otherwise need to prune the columns of the input of the next LSTM layer
            else:
                with torch.no_grad():
                    if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"):
                        parametrization_dict = cast(
                            nn.ModuleDict, lstm.parametrizations
                        )
                        weight_parameterizations = cast(
                            ParametrizationList,
                            getattr(parametrization_dict, f"weight_ih_l{i + 1}"),
                        )

                        weight_parameterizations.original = nn.Parameter(
                            weight_parameterizations.original[:, M_ho]
                        )
                    else:
                        next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}")
                        setattr(
                            lstm,
                            f"weight_ih_l{i + 1}",
                            nn.Parameter(next_layer_weight[:, M_ho]),
                        )