# mypy: allow-untyped-defs
import logging

from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import (
    SUPPORTED_TYPES,
)


logger: logging.Logger = logging.getLogger(__name__)


def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None):
    """Attach the parameters of ``module`` to ``data_sparsifier``.

    Loops over ``module.named_parameters()`` and registers every parameter
    whose exact type is in ``SUPPORTED_TYPES`` via ``data_sparsifier.add_data``.

    Args:
        module: the model whose parameters should be attached.
        data_sparsifier: object exposing ``add_data(name=..., data=..., **cfg)``.
        config: optional mapping from *sanitized* parameter name (see
            ``_get_valid_name``) to a dict of keyword arguments forwarded to
            ``add_data`` for that parameter. Parameters without an entry get
            no extra kwargs, i.e. the sparsifier's defaults.

    Note::
        The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below)
        before attaching to the sparsifier. This is because, the data
        sparsifier uses a dummy model inside to store the weight parameters.
    """
    if config is None:
        config = {}
    for name, parameter in module.named_parameters():
        # Exact type match (not isinstance), as in the original implementation.
        if type(parameter) in SUPPORTED_TYPES:
            valid_name = _get_valid_name(name)
            # will be defaulted to default configs
            data_sparsifier.add_data(
                name=valid_name, data=parameter, **config.get(valid_name, {})
            )


def _get_valid_name(name):
    """Return ``name`` with every '.' replaced by '_'.

    NOTE(review): this mapping is not injective (e.g. "a_b.c" and "a.b_c"
    both become "a_b_c"); distinct parameters could collide.
    """
    return name.replace(".", "_")  # . is not allowed as a name


def _log_sparsified_level(model, data_sparsifier) -> None:
    """Log, at INFO level, the fraction of masked-out entries per parameter.

    For every parameter of ``model`` whose exact type is in ``SUPPORTED_TYPES``,
    fetch its mask from ``data_sparsifier.get_mask`` (using the sanitized name)
    and log ``1 - mean(mask)`` as a percentage.
    """
    # Show the level of sparsity AFTER step:
    for name, parameter in model.named_parameters():
        if type(parameter) not in SUPPORTED_TYPES:
            continue
        valid_name = _get_valid_name(name)
        mask = data_sparsifier.get_mask(name=valid_name)
        # `1.0 - tensor` is a 0-d tensor; convert to a Python float so the
        # %-style log format below receives a plain number.
        sparsity_level = float(1.0 - mask.float().mean())
        # "%.2f%%" prints e.g. "50.00%". The previous "% .2%" consumed no
        # argument, so logging raised "not all arguments converted".
        logger.info("Sparsity in layer %s = %.2f%%", name, 100.0 * sparsity_level)