# mypy: allow-untyped-defs
import torch

from . import base_sparsifier


class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
    r"""Nearly Diagonal Sparsifier

    This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
    A nearly diagonal matrix is a matrix whose non-zero elements lie near the main
    diagonal; all other elements are zero.
    Examples of nearly diagonal matrices with degree (or nearliness) 3 and 5 are as
    follows, respectively:
    1 1 0 0       1 1 1 0
    1 1 1 0       1 1 1 1
    0 1 1 1       1 1 1 1
    0 0 1 1       0 1 1 1
    Note that a nearly diagonal matrix with degree 1 is just a matrix with only the
    main diagonal populated.

    This sparsifier is controlled by one variable:
    1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal.
        Currently only odd values are supported.

    Note:
        The band mask is built with a single broadcast comparison
        (``|col - row| <= nearliness // 2``) rather than a per-row Python loop.

    Args:
        nearliness: The degree of nearliness (default = 1)

    """

    def __init__(self, nearliness: int = 1):
        defaults = {"nearliness": nearliness}
        super().__init__(defaults=defaults)

    def update_mask(self, module, tensor_name, nearliness, **kwargs):
        """Rewrite the mask of ``module.<tensor_name>`` as a nearly diagonal band.

        Args:
            module: Module whose ``parametrizations.<tensor_name>[0]`` holds a ``mask``.
            tensor_name: Name of the parametrized tensor on ``module``.
            nearliness: Number of diagonals kept around (and including) the main
                diagonal. Values ``<= 0`` produce an all-zero mask.
            **kwargs: Additional per-group config entries; unused here.

        Raises:
            ValueError: If the tensor is not 2D, if ``nearliness`` is even, or if
                ``nearliness // 2`` is not smaller than both tensor dimensions.
                All validation happens before the mask is modified, so an invalid
                configuration leaves the existing mask untouched.
        """
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if nearliness <= 0:
            mask.data = torch.zeros_like(mask)
            return

        tensor = getattr(module, tensor_name)
        if tensor.dim() != 2:
            raise ValueError(
                f"{type(self).__name__} only supports 2D tensors, but "
                f"'{tensor_name}' has shape {tuple(tensor.shape)}"
            )
        height, width = tensor.shape

        if nearliness % 2 == 0:
            raise ValueError("nearliness can only be an odd number")
        dist_to_diagonal = nearliness // 2
        if dist_to_diagonal >= min(height, width):
            raise ValueError(
                "nearliness cannot be larger than the dimensions of tensor."
            )

        # Entry (row, col) is kept iff it lies within dist_to_diagonal of the main
        # diagonal; this equals the clamped slice [row - d, row + d] per row.
        rows = torch.arange(height, device=mask.device).unsqueeze(1)
        cols = torch.arange(width, device=mask.device).unsqueeze(0)
        band = (cols - rows).abs() <= dist_to_diagonal

        # NOTE(review): assumes mask has the same shape as the parametrized tensor
        # (band is (height, width)) — confirm against how the mask is created.
        mask.data = torch.zeros_like(mask)
        mask.masked_fill_(band, 1)