1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import torch
31
32from .base_sparsifier import BaseSparsifier
33from .common import sparsify_matrix
34
35
class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (linear, params), where linear is an instance
                of torch.nn.Linear and params is a tuple (density, [m, n]), where density is the
                target density in [0, 1] and [m, n] is the shape of the sub-blocks to which
                sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop, sparsification
                will be carried out after every call to LinearSparsifier.step()

            exponent : float
                Interpolation exponent for the sparsification interval. In step i sparsification will
                be carried out with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> linear = torch.nn.Linear(8, 16)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
            >>> for i in range(100):
            ...     sparsifier.step()
        """

        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ Carries out one sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            -----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if True, densities are printed out

            Returns:
            --------
            None

        """

        with torch.no_grad():
            for linear, params in self.task_list:
                # if weight normalization is active, sparsify the underlying weight_v tensor
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight
                target_density, block_size = params
                # interpolate between fully dense (alpha=1) and the target density (alpha=0)
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                # warn if previously pruned weights have become non-zero again
                # (note: only the most recently computed mask is tracked)
                if self.last_mask is not None:
                    if not torch.all(self.last_mask * new_mask == new_mask) and verbose:
                        print("weight resurrection in linear.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")


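# A minimal sketch (not part of the upstream module) of the density schedule
# described in the constructor docstring above, assuming
#     alpha   = ((stop - i) / (stop - start)) ** exponent
#     density = alpha + (1 - alpha) * target_density
# for steps i between start and stop. The helper name is hypothetical and only
# meant to illustrate how quickly a layer is pruned towards its target density.
def _density_schedule(step, start, stop, target_density, exponent=3):
    """Interpolated density at a given training step (illustration only)."""
    if step <= start:
        return 1.0
    if step >= stop:
        return target_density
    alpha = ((stop - step) / (stop - start)) ** exponent
    return alpha + (1 - alpha) * target_density

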
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch
    linear = torch.nn.Linear(8, 16)
    params = (0.2, [4, 2])

    sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(linear.weight)
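    # Illustrative follow-up check (not part of the upstream test): compare the
    # fraction of surviving weights against the configured target density of 0.2.
    # The achieved value may differ slightly since pruning is applied block-wise
    # and the schedule may not have fully reached the target by step 100.
    nonzero_fraction = torch.count_nonzero(linear.weight).item() / linear.weight.numel()
    print(f"achieved density: {nonzero_fraction:.3f} (target density: 0.2)")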