# mypy: allow-untyped-defs
import logging

from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import (
    SUPPORTED_TYPES,
)


logger: logging.Logger = logging.getLogger(__name__)


def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None):
    """Attach a data sparsifier to all supported layers of the module.

    Essentially, loop over all the weight parameters in the module and
    attach each of them to the data sparsifier.

    Note::
        The '.' in the layer names is replaced with '_' (refer to _get_valid_name() below)
        before attaching to the sparsifier. This is because the data
        sparsifier internally uses a dummy model to store the weight parameters.
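
    Example (illustrative sketch; assumes the ``DataNormSparsifier`` from
    ``torch.ao.pruning._experimental.data_sparsifier`` as the concrete sparsifier)::

        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>> from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
        >>> model = nn.Sequential(nn.Linear(8, 4))
        >>> sparsifier = DataNormSparsifier(sparsity_level=0.5)
        >>> # per-layer overrides are keyed by the dot-free name, e.g. '0.weight' -> '0_weight'
        >>> config = {"0_weight": {"sparsity_level": 0.8}}
        >>> _attach_model_to_data_sparsifier(model, sparsifier, config=config)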
    """
    if config is None:
        config = {}
    for name, parameter in module.named_parameters():
        if type(parameter) in SUPPORTED_TYPES:
            valid_name = _get_valid_name(name)
            # falls back to the sparsifier's default config when no per-layer config is given
            data_sparsifier.add_data(
                name=valid_name, data=parameter, **config.get(valid_name, {})
            )


def _get_valid_name(name):
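    """Return ``name`` with every '.' replaced by '_' (e.g. ``fc1.weight`` -> ``fc1_weight``),
    since '.' cannot be used in the names under which the data sparsifier stores parameters.
    """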
    return name.replace(".", "_")  # . is not allowed as a name


def _log_sparsified_level(model, data_sparsifier) -> None:
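    """Log, for each supported parameter of ``model``, the sparsity level implied by the
    mask held by ``data_sparsifier`` (i.e. the fraction of entries that are zeroed out).
    """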
    # Show the level of sparsity AFTER step:
    for name, parameter in model.named_parameters():
        if type(parameter) not in SUPPORTED_TYPES:
            continue
        valid_name = _get_valid_name(name)
        mask = data_sparsifier.get_mask(name=valid_name)
        sparsity_level = 1.0 - mask.float().mean()
        logger.info("Sparsity in layer %s = %.2f%%", name, 100.0 * sparsity_level)