"""Helpers for marking modules for sparsification and tracking parameter counts."""
import torch

from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier


def mark_for_sparsification(module, params):
    """Flag ``module`` for sparsification and attach its sparsification parameters.

    ``create_sparsifier`` and ``estimate_nonzero_parameters`` detect marked
    modules through the ``sparsify`` attribute set here and read
    ``sparsification_params``.

    Args:
        module: the ``torch.nn.Module`` to mark.
        params: sparsifier parameters for this module. The expected format
            depends on the module type (see ``estimate_nonzero_parameters``).

    Returns:
        The same ``module`` object, so the call can be used inline while
        building a model.
    """
    module.sparsify = True
    module.sparsification_params = params
    return module


def create_sparsifier(module, start, stop, interval):
    """Build a callable that steps a sparsifier for every marked submodule.

    Walks ``module.modules()`` (which includes ``module`` itself). For every
    submodule carrying the ``sparsify`` attribute, it constructs the matching
    sparsifier with that submodule's ``sparsification_params``. Marked
    submodules of unsupported types are reported with a printed warning and
    otherwise ignored.

    Args:
        module: root ``torch.nn.Module`` to search.
        start, stop, interval: forwarded unchanged to every sparsifier
            constructor. Their meaning is defined by the sparsifier classes.

    Returns:
        A function ``sparsify(verbose=False)`` that calls ``step(verbose)`` on
        each sparsifier, in the order the submodules were visited.
    """
    # Ordered (module type, sparsifier class) pairs; the first isinstance match wins.
    dispatch = (
        (torch.nn.GRU, GRUSparsifier),
        (torch.nn.Linear, LinearSparsifier),
        (torch.nn.Conv1d, Conv1dSparsifier),
        (torch.nn.ConvTranspose1d, ConvTranspose1dSparsifier),
    )

    sparsifier_list = []
    for m in module.modules():
        if not hasattr(m, 'sparsify'):
            continue
        for module_type, sparsifier_class in dispatch:
            if isinstance(m, module_type):
                sparsifier_list.append(
                    sparsifier_class([(m, m.sparsification_params)], start, stop, interval)
                )
                break
        else:
            print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")

    def sparsify(verbose=False):
        for sparsifier in sparsifier_list:
            sparsifier.step(verbose)

    return sparsify


def count_parameters(model, verbose=False):
    """Return the total number of parameter elements in ``model``.

    Args:
        model: a ``torch.nn.Module``.
        verbose: if True, print the element count of every named parameter.

    Returns:
        The summed element count over ``model.named_parameters()`` as an int.
    """
    total = 0
    for name, p in model.named_parameters():
        # numel() is exact for every dtype. Summing torch.ones_like(p) allocated
        # a full-size tensor and was inexact, or overflowed to inf, for
        # low-precision dtypes such as float16/bfloat16.
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total


def _estimate_zero_parameters(module):
    """Return the estimated number of zeroed weights in one module marked for sparsification.

    Raises:
        ValueError: if ``module`` is of a type with no known sparsification method.
    """
    params = module.sparsification_params
    if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
        # params[0] is the density (fraction of weights kept).
        return module.weight.numel() * (1 - params[0])
    if isinstance(module, torch.nn.GRU):
        # params maps each gate matrix name to a tuple whose [0] is its density.
        # NOTE(review): only the layer-0, forward-direction weight matrices are
        # counted (input_size x hidden_size per input gate, hidden_size x
        # hidden_size per recurrent gate). Multi-layer or bidirectional GRUs are
        # not covered — confirm this matches GRUSparsifier.
        num_zero = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
        num_zero += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
        return num_zero
    if isinstance(module, torch.nn.Linear):
        # params[0] is the density, so the zeroed share is (1 - density),
        # consistent with the Conv1d and GRU branches above.
        return module.in_features * module.out_features * (1 - params[0])
    raise ValueError(f'unknown sparsification method for module of type {type(module)}')


def estimate_nonzero_parameters(module):
    """Estimate the number of nonzero parameters in ``module`` once sparsified to target density.

    Starts from the total parameter count (``count_parameters``). It then
    subtracts the expected number of zeroed weights for every submodule
    (including ``module`` itself) marked via ``mark_for_sparsification``.
    Biases and unmarked modules are assumed to stay dense.

    Supported ``sparsification_params`` formats:
        * ``torch.nn.Conv1d`` / ``torch.nn.ConvTranspose1d`` / ``torch.nn.Linear``:
          a sequence whose element ``[0]`` is the weight density.
        * ``torch.nn.GRU``: a mapping with keys ``'W_ir'``, ``'W_iz'``, ``'W_in'``,
          ``'W_hr'``, ``'W_hz'``, ``'W_hn'``, each value's ``[0]`` being that
          matrix's density.

    Returns:
        The estimated nonzero parameter count. It is a float when densities
        are floats.

    Raises:
        ValueError: if a marked submodule has no known sparsification method.
    """
    num_zero_parameters = 0
    for m in module.modules():
        if hasattr(m, 'sparsify'):
            num_zero_parameters += _estimate_zero_parameters(m)
    return count_parameters(module) - num_zero_parameters