# xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/sparsifier/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from itertools import chain
from typing import Any, Dict, Optional, Type

from torch import nn
from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations


__all__ = [
    "module_contains_param",
    "swap_module",
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
    if is_parametrized(module):
        # see if any of the module tensors have a parametrization attached that matches the one passed in
        return any(
            any(isinstance(param, parametrization) for param in param_list)
            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
        )
    return False
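# Illustrative usage of module_contains_param (a sketch, not part of the original
# file), assuming a plain nn.Linear and the FakeSparsity parametrization defined
# further below:
#
#     import torch
#     from torch.nn.utils import parametrize
#
#     linear = nn.Linear(4, 4)
#     parametrize.register_parametrization(linear, "weight", FakeSparsity(torch.ones(4, 4)))
#     module_contains_param(linear, FakeSparsity)           # True
#     module_contains_param(nn.Linear(4, 4), FakeSparsity)  # False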


def swap_module(
    mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> nn.Module:
    r"""Swaps the module using ``from_dense`` according to the mapping passed in.
    Args:
        mod: input module
        mapping: a dictionary that maps a dense nn.Module type to its sparse counterpart
    Returns:
        The corresponding sparse module of `mod` according to the mapping, created using ``from_dense``
    """
    if type_before_parametrizations(mod) in mapping:
        sparse_mod = mapping[type_before_parametrizations(mod)]

        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]

        # Preserve the module's pre forward hooks. They'll be called on quantized input
        for pre_hook_fn in mod._forward_pre_hooks.values():
            new_mod.register_forward_pre_hook(pre_hook_fn)
        # Preserve the module's post forward hooks except _observer_forward_hook
        # After convert they'll work with quantized output
        for hook_fn in mod._forward_hooks.values():
            new_mod.register_forward_hook(hook_fn)

        # respect device affinity when swapping modules
        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
        assert (
            len(devices) <= 1
        ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None
        if device:
            new_mod.to(device)

        return new_mod

    else:
        return mod
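# Illustrative usage of swap_module (a sketch, not part of the original file).
# The only requirement on the mapped-to class is that it provide a ``from_dense``
# classmethod; ``SparseLinearStub`` below is a hypothetical stand-in:
#
#     class SparseLinearStub(nn.Linear):
#         @classmethod
#         def from_dense(cls, mod):
#             new = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
#             new.load_state_dict(mod.state_dict())
#             return new
#
#     dense = nn.Linear(8, 8)
#     swapped = swap_module(dense, {nn.Linear: SparseLinearStub})
#     assert isinstance(swapped, SparseLinearStub)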


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn (fully qualified name) for a module, or None if the module is not a descendant of the model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None
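# Illustrative usage of module_to_fqn (a sketch, not part of the original file):
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     module_to_fqn(model, model[0])   # "0"
#     module_to_fqn(model, model)      # ""
#     module_to_fqn(model, nn.ReLU())  # None, a fresh module is not a descendant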


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn given by `path`
    doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model
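# Illustrative usage of fqn_to_module (a sketch, not part of the original file):
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     fqn_to_module(model, "0")         # the nn.Linear submodule
#     fqn_to_module(model, "0.weight")  # the Linear's weight tensor
#     fqn_to_module(model, "")          # the model itself
#     fqn_to_module(model, "missing")   # None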


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module, tensor_name and tensor_fqn
    """
    # string manip to split tensor_fqn into module_fqn and tensor_name
    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

    module = fqn_to_module(model, module_fqn)

    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
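# Illustrative usage of get_arg_info_from_tensor_fqn (a sketch, not part of the
# original file):
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     get_arg_info_from_tensor_fqn(model, "0.weight")
#     # -> {'module_fqn': '0', 'module': <the nn.Linear>,
#     #     'tensor_name': 'weight', 'tensor_fqn': '0.weight'}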


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to the 'weight' or
    any other parameter that requires a mask to be applied to it.

    Note::

        Once the mask is passed in, its id should not change. The contents of
        the mask can change, but the mask object itself should not be replaced.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrization to save the mask.
        # That way we make sure that the parametrized module doesn't store the
        # masks alongside its parametrizations.
        return {}
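# Illustrative usage of FakeSparsity (a sketch, not part of the original file),
# attaching the parametrization directly via torch.nn.utils.parametrize so that
# accessing the weight returns the masked tensor:
#
#     import torch
#     from torch.nn.utils import parametrize
#
#     linear = nn.Linear(4, 4)
#     mask = torch.ones(4, 4)
#     mask[0] = 0  # zero out the first output row
#     parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))
#     assert torch.all(linear.weight[0] == 0)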
139