1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch


from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug


class ConvTranspose1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (conv, params), where conv is an instance
                of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
                where density is the target density in [0, 1] and [m, n] is the shape of the
                sub-blocks to which sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop, sparsification
                will be carried out after every call to ConvTranspose1dSparsifier.step()

            exponent : float
                Interpolation exponent for the sparsification interval. In step i sparsification will
                be carried out with density (alpha + (1 - alpha) * target_density), where
                alpha = ((stop - i) / (stop - start)) ** exponent

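                For instance, with start=0, stop=100 and exponent=3, step i=50 yields
                alpha = ((100 - 50) / (100 - 0)) ** 3 = 0.125 and hence density
                0.125 + 0.875 * target_density.
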
            Example:
            --------
            >>> import torch
            >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
            >>> for i in range(100):
            ...     sparsifier.step()
        """

        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

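            Example:
            --------
            A minimal sketch of a direct call; in a training loop this method is
            typically invoked indirectly through step(), as in the class example above.

            >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
            >>> sparsifier.sparsify(alpha=0.5)
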
        """

        with torch.no_grad():
            for conv, params in self.task_list:
                # if weight normalization is used, sparsify the direction parameter weight_v
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # ConvTranspose1d weights have shape (in_channels, out_channels, kernel_size);
                # flatten to a (kernel_size * out_channels, in_channels) matrix for block sparsification
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                # a previously zeroed entry becoming non-zero again indicates weight resurrection
                if self.last_mask is not None:
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsifier[{self.step_counter}]: {density=}")


if __name__ == "__main__":
    print("Testing sparsifier")

    import torch
    conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    params = (0.2, [8, 4])

    sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)
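
    # Illustrative sanity check: estimate the achieved block density by counting
    # non-zero [8, 4] sub-blocks of the flattened weight and compare it with the
    # 0.2 target. The reshape mirrors the one used in sparsify() above.
    with torch.no_grad():
        target_density, (bm, bn) = params
        i, o, k = conv.weight.shape
        w = conv.weight.permute(2, 1, 0).reshape(k * o, i)
        blocks = w.reshape(k * o // bm, bm, i // bn, bn)
        block_density = (blocks.abs().sum(dim=(1, 3)) > 0).float().mean().item()
        print(f"measured block density: {block_density:.3f} (target: {target_density})")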