xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from . import base_sparsifier
5
6
class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
    r"""Nearly Diagonal Sparsifier

    This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
    A nearly diagonal matrix is a matrix whose non-zero elements lie near the main
    diagonal; all other elements are zero. Examples of nearly diagonal matrices with
    degree (or nearliness) 3 and 5, respectively, are shown below::

        1 1 0 0       1 1 1 0
        1 1 1 0       1 1 1 1
        0 1 1 1       1 1 1 1
        0 0 1 1       0 1 1 1

    Note that a nearly diagonal matrix with degree 1 is just a matrix with only the
    main diagonal populated.

    This sparsifier is controlled by one variable:
    1. `nearliness` defines the number of non-zero diagonal lines that are closest to
       the main diagonal (the main diagonal included). Currently only odd values are
       supported; a non-positive value produces an all-zero mask.

    Note:
        Only 2-D tensors are supported. The band is built in a single
        ``tril_``/``triu_`` pass rather than with a per-row loop.

    Args:
        nearliness: The degree of nearliness (default = 1)

    """

    def __init__(self, nearliness: int = 1):
        defaults = {"nearliness": nearliness}
        super().__init__(defaults=defaults)

    def update_mask(self, module, tensor_name: str, nearliness: int, **kwargs) -> None:
        r"""Rebuild the mask of ``tensor_name`` on ``module`` as a nearly diagonal band.

        Args:
            module: Module whose ``parametrizations.<tensor_name>[0]`` holds the mask.
            tensor_name: Name of the parametrized tensor on ``module``.
            nearliness: Number of diagonals (main diagonal included) kept non-zero.
                Must be odd when positive; a value ``<= 0`` zeroes the whole mask.
            **kwargs: Ignored.

        Raises:
            ValueError: If the tensor is not 2-D, ``nearliness`` is a positive even
                number, or ``nearliness // 2`` is not smaller than
                ``min(height, width)``. All checks run before the mask is modified,
                so on error the existing mask is left untouched.
        """
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if nearliness <= 0:
            # No diagonal is kept: prune every entry.
            mask.data = torch.zeros_like(mask)
            return

        # Only the shape of the tensor is needed for validation.
        tensor = getattr(module, tensor_name)
        if tensor.dim() != 2:
            raise ValueError(
                "NearlyDiagonalSparsifier only supports 2-D tensors, "
                f"got shape {tuple(tensor.shape)}."
            )
        height, width = tensor.shape

        if nearliness % 2 == 0:
            raise ValueError(
                f"nearliness can only be an odd number, got {nearliness}."
            )
        dist_to_diagonal = nearliness // 2
        if dist_to_diagonal >= min(height, width):
            raise ValueError(
                "nearliness cannot be larger than the dimensions of tensor: "
                f"nearliness // 2 ({dist_to_diagonal}) must be smaller than "
                f"min(height, width) ({min(height, width)})."
            )

        # Keep entry (i, j) iff |i - j| <= dist_to_diagonal:
        # tril_(d) keeps j - i <= d and triu_(-d) keeps j - i >= -d.
        mask.data = (
            torch.ones_like(mask).tril_(dist_to_diagonal).triu_(-dist_to_diagonal)
        )
59