"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch


from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug


class ConvTranspose1dSparsifier(BaseSparsifier):
    """Block sparsifier for the weights of torch.nn.ConvTranspose1d layers.

    The schedule parameters (start, stop, interval, exponent) are forwarded to
    BaseSparsifier; this class implements the actual weight pruning in
    ``sparsify``.
    """

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv, params), where conv is an instance
            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the
            sub-blocks to which sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to ConvTranspose1dSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
            (alpha itself is derived by BaseSparsifier; 1 means dense, 0 means target density)

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...         sparsifier.step()
        """

        # forward the caller's exponent (previously hard-coded to 3, which
        # silently ignored the argument)
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # masks from the previous sparsification step, keyed by the task's
        # position in task_list; each conv has its own mask shape, so masks
        # must not be compared across tasks
        self.last_masks = {}

        # most recently computed mask (of the last task processed); kept for
        # backward compatibility with code that reads this attribute
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        In a training loop, call step() after optimizer.step(); step() is
        expected to invoke this method with the scheduled alpha (see
        BaseSparsifier).

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for task_index, (conv, params) in enumerate(self.task_list):
                # prefer the `weight_v` parameter when present (as created by
                # torch.nn.utils.weight_norm), otherwise prune `weight` directly
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight

                # view the (i, o, k) weight as a (k * o, i) matrix so that
                # block sparsification runs over input channels per
                # (kernel tap, output channel) row
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)

                target_density, block_size = params
                # linear interpolation between dense (alpha=1) and target (alpha=0)
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)

                # undo the reshape/permute and write back in place
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                # debug check: a weight is "resurrected" if it is kept by the
                # new mask although this task's previous mask had pruned it
                last_mask = self.last_masks.get(task_index)
                if debug and last_mask is not None:
                    if not torch.all(last_mask * new_mask == new_mask):
                        print("weight resurrection in conv.weight")

                self.last_masks[task_index] = new_mask
                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")


if __name__ == "__main__":
    print("Testing sparsifier")

    conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    params = (0.2, [8, 4])

    sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)