"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix


class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (linear, params), where linear is an instance
                of torch.nn.Linear and params is a tuple (density, [m, n]),
                where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks
                to which sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to LinearSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + (1 - alpha) * target_density), where alpha is the interpolation
                parameter supplied by BaseSparsifier (1: dense, 0: target density); presumably
                alpha = ((stop - i) / (stop - start)) ** exponent -- confirm against BaseSparsifier.
                Defaults to 3.

            Example:
            --------
            >>> import torch
            >>> linear = torch.nn.Linear(8, 16)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """

        # forward the caller's exponent instead of silently pinning it to 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # most recently computed mask (kept for backward compatibility)
        self.last_mask = None
        # previous mask of every task, keyed by its position in task_list, so
        # that masks of different layers are never compared with each other
        self._last_masks = {}

    def sparsify(self, alpha, verbose=False, debug=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out
            debug : bool
                if true, a message is printed whenever the new mask of a layer
                keeps a weight that the previous mask of the same layer had removed

            Returns:
            --------
            None

        """

        with torch.no_grad():
            for index, (linear, params) in enumerate(self.task_list):
                # prefer weight_v when the layer exposes it, otherwise use the plain weight
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                previous_mask = self._last_masks.get(index)
                if debug and previous_mask is not None:
                    # new_mask must be contained in previous_mask, otherwise a
                    # previously pruned weight has been re-enabled
                    if not torch.all(previous_mask * new_mask == new_mask):
                        print(f"weight resurrection in linear.weight (task {index})")

                self._last_masks[index] = new_mask
                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")


if __name__ == "__main__":
    print("Testing sparsifier")

    linear = torch.nn.Linear(8, 16)
    params = (0.2, [4, 2])

    sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(linear.weight)