# mypy: allow-untyped-defs
from itertools import chain
from typing import Any, Dict, Optional, Type

from torch import nn
from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations


__all__ = [
    "module_contains_param",
    "swap_module",
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
    """
    Returns True if any tensor of `module` has a parametrization of the given type attached.
    """
    if is_parametrized(module):
        # check if any of the module's tensors have a parametrization attached that matches the one passed in
        return any(
            any(isinstance(param, parametrization) for param in param_list)
            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
        )
    return False


def swap_module(
    mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> nn.Module:
    r"""Swaps the module for its sparse counterpart using from_dense, according to the mapping passed in.

    Args:
        mod: input module
        mapping: a dictionary that maps from a dense nn module type to the corresponding sparse nn module type
    Returns:
        The sparse module corresponding to `mod` according to mapping, created using from_dense
    """
    if type_before_parametrizations(mod) in mapping:
        sparse_mod = mapping[type_before_parametrizations(mod)]

        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]

        # Preserve the module's forward pre-hooks; they will run on the new module's input
        for pre_hook_fn in mod._forward_pre_hooks.values():
            new_mod.register_forward_pre_hook(pre_hook_fn)
        # Preserve the module's forward hooks; they will run on the new module's output
        for hook_fn in mod._forward_hooks.values():
            new_mod.register_forward_hook(hook_fn)

        # respect device affinity when swapping modules
        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
        assert (
            len(devices) <= 1
        ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None
        if device:
            new_mod.to(device)

        return new_mod

    else:
        return mod


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn for a module, or None if the module is not a descendant of the model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn given by `path`
    doesn't correspond to anything. Similar to model.get_submodule(path), but also works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name.
    """
    # string manipulation to split tensor_fqn into module_fqn and tensor_name
    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

    module = fqn_to_module(model, module_fqn)

    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to 'weight' or
    any other parameter that requires a mask to be applied to it.

    Note::

        Once the mask is passed in, its id should not change. The contents
        of the mask may change, but the mask object itself should not be
        replaced.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrization to save the mask.
        # This makes sure that the parametrized module doesn't store the masks
        # alongside its parametrizations.
        return {}
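

# Minimal usage sketch (illustrative only, not part of this module's API).
# It assumes nothing beyond `torch` and the helpers defined above: it attaches
# a FakeSparsity mask to a linear layer's weight via torch.nn.utils.parametrize
# and then uses the FQN helpers to locate that tensor.
if __name__ == "__main__":
    import torch
    from torch.nn.utils import parametrize

    model = nn.Sequential(nn.Linear(4, 4))
    mask = torch.ones_like(model[0].weight)
    parametrize.register_parametrization(model[0], "weight", FakeSparsity(mask))

    # The linear layer now carries a FakeSparsity parametrization on its weight.
    assert module_contains_param(model[0], FakeSparsity)

    # Resolve the tensor FQN back to its owning module and tensor name.
    info = get_arg_info_from_tensor_fqn(model, "0.weight")
    assert info["module_fqn"] == "0" and info["tensor_name"] == "weight"
    assert module_to_fqn(model, model[0]) == "0"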