xref: /aosp_15_r20/external/libopus/dnn/torch/dnntools/dnntools/sparsification/utils.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import torch
2
3from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
4
def mark_for_sparsification(module, params):
    """Tag *module* for sparsification and return it.

    Sets ``module.sparsify`` to ``True`` and stores *params* unchanged as
    ``module.sparsification_params``. The very same object is returned.
    """
    module.sparsify = True
    module.sparsification_params = params
    return module
9
def create_sparsifier(module, start, stop, interval):
    """Build one stepping function for every marked layer inside *module*.

    Walks ``module.modules()`` and, for each submodule that carries a
    ``sparsify`` attribute, instantiates the sparsifier matching its layer
    type (GRU, Linear, Conv1d or ConvTranspose1d) from a one-element task
    list ``[(layer, layer.sparsification_params)]`` together with *start*,
    *stop* and *interval*. A marked submodule of any other type is reported
    with a printed warning and skipped.

    Returns:
        A callable ``sparsify(verbose=False)`` that invokes
        ``step(verbose)`` on each created sparsifier, in creation order.
    """
    steppers = []

    for layer in module.modules():
        # Only the presence of the attribute is tested, not its value.
        if not hasattr(layer, 'sparsify'):
            continue

        # Pick the sparsifier class for this layer type.
        if isinstance(layer, torch.nn.GRU):
            sparsifier_cls = GRUSparsifier
        elif isinstance(layer, torch.nn.Linear):
            sparsifier_cls = LinearSparsifier
        elif isinstance(layer, torch.nn.Conv1d):
            sparsifier_cls = Conv1dSparsifier
        elif isinstance(layer, torch.nn.ConvTranspose1d):
            sparsifier_cls = ConvTranspose1dSparsifier
        else:
            print(f"[create_sparsifier] warning: module {layer} marked for sparsification but no suitable sparsifier exists.")
            continue

        task = [(layer, layer.sparsification_params)]
        steppers.append(sparsifier_cls(task, start, stop, interval))

    def sparsify(verbose=False):
        for stepper in steppers:
            stepper.step(verbose)

    return sparsify
38
39
def count_parameters(model, verbose=False):
    """Return the total number of elements across all parameters of *model*.

    Args:
        model: object whose ``named_parameters()`` yields ``(name, tensor)``
            pairs, e.g. a ``torch.nn.Module``.
        verbose: when true, print one line per parameter with its name and
            element count.

    Returns:
        int: the sum of ``numel()`` over all parameters; 0 when the model
        has none.
    """
    total = 0
    for name, p in model.named_parameters():
        # numel() is an exact integer read from the tensor's shape. Summing a
        # ones_like() tensor instead allocates a full copy and counts in the
        # parameter's own dtype, which overflows to inf for large float16
        # tensors and loses precision for very large float32 ones.
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total
51
def estimate_nonzero_parameters(module):
    """Estimate how many parameters of *module* stay nonzero once sparsified.

    The estimate is the total parameter count of *module* (weights and
    biases, including submodules) minus the number of weights expected to
    be zeroed according to ``module.sparsification_params``. A module
    without a ``sparsify`` attribute is treated as dense, so its full
    parameter count is returned.

    NOTE(review): ``params[0]`` (``params[key][0]`` for GRUs) is treated as
    the target density, i.e. the fraction of weights kept, which is how the
    Conv1d and GRU formulas use it. Confirm against the sparsifier classes.

    NOTE(review): the GRU formula only covers one ``input_size x hidden_size``
    and one ``hidden_size x hidden_size`` matrix per gate; any further
    layers or directions are counted as dense.

    Args:
        module: a ``torch.nn.Module``, optionally carrying ``sparsify`` and
            ``sparsification_params`` attributes.

    Returns:
        The estimated number of nonzero parameters (an int for unmarked
        modules, typically a float when densities are floats).

    Raises:
        ValueError: if *module* is marked for sparsification but is not a
            Conv1d, ConvTranspose1d, GRU or Linear layer.
    """
    total = sum(p.numel() for p in module.parameters())

    if not hasattr(module, 'sparsify'):
        return total

    params = module.sparsification_params
    if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
        num_zero_parameters = module.weight.numel() * (1 - params[0])
    elif isinstance(module, torch.nn.GRU):
        # Each matrix stacks three gates, every gate with its own density.
        num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
        num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
    elif isinstance(module, torch.nn.Linear):
        # Zeroed share is (1 - density), matching the Conv1d and GRU branches.
        num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
    else:
        raise ValueError(f'unknown sparsification method for module of type {type(module)}')

    return total - num_zero_parameters
65