1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import torch 31 32def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False): 33 """ sparsifies matrix with specified block size 34 35 Parameters: 36 ----------- 37 matrix : torch.tensor 38 matrix to sparsify 39 density : int 40 target density 41 block_size : [int, int] 42 block size dimensions 43 keep_diagonal : bool 44 If true, the diagonal will be kept. 
This option requires block_size[0] == block_size[1] and defaults to False 45 """ 46 47 m, n = matrix.shape 48 m1, n1 = block_size 49 50 if m % m1 or n % n1: 51 raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}") 52 53 # extract diagonal if keep_diagonal = True 54 if keep_diagonal: 55 if m != n: 56 raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True") 57 58 to_spare = torch.diag(torch.diag(matrix)) 59 matrix = matrix - to_spare 60 else: 61 to_spare = torch.zeros_like(matrix) 62 63 # calculate energy in sub-blocks 64 x = torch.reshape(matrix, (m // m1, m1, n // n1, n1)) 65 x = x ** 2 66 block_energies = torch.sum(torch.sum(x, dim=3), dim=1) 67 68 number_of_blocks = (m * n) // (m1 * n1) 69 number_of_survivors = round(number_of_blocks * density) 70 71 # masking threshold 72 if number_of_survivors == 0: 73 threshold = 0 74 else: 75 threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors] 76 77 # create mask 78 mask = torch.ones_like(block_energies) 79 mask[block_energies < threshold] = 0 80 mask = torch.repeat_interleave(mask, m1, dim=0) 81 mask = torch.repeat_interleave(mask, n1, dim=1) 82 83 # perform masking 84 masked_matrix = mask * matrix + to_spare 85 86 if return_mask: 87 return masked_matrix, mask 88 else: 89 return masked_matrix 90 91def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False): 92 input_size = gru.input_size 93 hidden_size = gru.hidden_size 94 flops = 0 95 96 input_density = ( 97 sparsification_dict.get('W_ir', [1])[0] 98 + sparsification_dict.get('W_in', [1])[0] 99 + sparsification_dict.get('W_iz', [1])[0] 100 ) / 3 101 102 recurrent_density = ( 103 sparsification_dict.get('W_hr', [1])[0] 104 + sparsification_dict.get('W_hn', [1])[0] 105 + sparsification_dict.get('W_hz', [1])[0] 106 ) / 3 107 108 # input matrix vector multiplications 109 if not drop_input: 110 flops += 2 * 3 * input_size * hidden_size * input_density 111 112 # recurrent matrix vector multiplications 113 flops += 2 * 3 * hidden_size * hidden_size * recurrent_density 114 115 # biases 116 flops += 6 * hidden_size 117 118 # activations estimated by 10 flops per activation 119 flops += 30 * hidden_size 120 121 return flops 122
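

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# matrix size, density, block size, and GRU dimensions below are made-up
# example values; only the leading density entry of each sparsification_dict
# value is read by calculate_gru_flops_per_step, so one-element lists are
# enough for this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # sparsify a 64x64 weight matrix to 50% density in 8x8 blocks, keeping
    # the diagonal and also retrieving the block mask
    weight = torch.randn(64, 64)
    sparse_weight, mask = sparsify_matrix(
        weight, density=0.5, block_size=(8, 8), keep_diagonal=True, return_mask=True
    )
    print("non-zero fraction:", float((sparse_weight != 0).float().mean()))

    # estimate flops per step for a GRU whose input and recurrent weights
    # are all sparsified to 50% density
    gru = torch.nn.GRU(input_size=64, hidden_size=64, batch_first=True)
    densities = {name: [0.5] for name in
                 ('W_ir', 'W_iz', 'W_in', 'W_hr', 'W_hz', 'W_hn')}
    print("approx. flops per step:", calculate_gru_flops_per_step(gru, densities))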