xref: /aosp_15_r20/external/pytorch/torch/ao/nn/sparse/quantized/dynamic/linear.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import Optional

import torch
import torch.ao.nn.intrinsic as nni
from torch.ao.nn.quantized.modules.utils import (
    _hide_packed_params_repr,
    _quantize_weight,
)
from torch.ao.nn.sparse.quantized import linear
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern


__all__ = ["Linear"]


class Linear(torch.nn.Module):
    r"""
    A dynamically quantized sparse linear module with float tensors as inputs and outputs.
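
    A minimal usage sketch (illustrative only; the shapes and block sizes below are
    assumptions, and the sparse ``qlinear_dynamic`` op must be available in the build)::

        >>> # xdoctest: +SKIP
        >>> m = Linear(in_features=8, out_features=4, row_block_size=1, col_block_size=4)
        >>> x = torch.randn(2, 8)
        >>> y = m(x)  # float in, float out; the packed weight stays quantized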
    """
    _version = 1
    _op_type = "sparse_dynamic"
    _FLOAT_MODULE = torch.nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        row_block_size,
        col_block_size,
        bias=True,
        dtype=torch.qint8,
    ):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError(
                "Only QINT8 is supported for Sparse Quantized Linear Dynamic"
            )

        self.in_features = in_features
        self.out_features = out_features

        if bias:
            bias = torch.zeros(self.out_features, dtype=torch.float)
        else:
            bias = None

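        # Start from a placeholder quantized weight; real values are installed
        # later via set_weight_bias() (e.g. from from_float()).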
        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
        )
        self._packed_params = linear.LinearPackedParams(
            row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
        )
        self._packed_params.set_weight_bias(
            qweight, bias, row_block_size, col_block_size
        )

    def _get_name(self):
        return "SparseQuantizedDynamicLinear"

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}"

    def __repr__(self):
        return _hide_packed_params_repr(self, linear.LinearPackedParams)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
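        # The input stays a float tensor; activation quantization parameters are
        # computed on the fly inside the sparse dynamic qlinear op, while the
        # weight comes from the pre-packed (block-sparse, qint8) params.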
        return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "op_type"] = self._op_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        op_type = state_dict[prefix + "op_type"]
        assert (
            op_type == self._op_type
        ), f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
        state_dict.pop(prefix + "op_type")

        version = local_metadata.get("version", None)
        assert version is None or version <= self._version

        # Is this code valid? In the old quantization flow it seemed to be used
        # to load older models.
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias")
        state_dict.update(
            {
                prefix + "_packed_params.weight": weight,
                prefix + "_packed_params.bias": bias,
            }
        )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _weight_bias(self):
        return self._packed_params._weight_bias()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def set_weight_bias(
        self,
        w: torch.Tensor,
        b: Optional[torch.Tensor],
        row_block_size: Optional[int],
        col_block_size: Optional[int],
    ) -> None:
        assert row_block_size is not None and col_block_size is not None
        self.out_features = w.shape[0]
        self.in_features = w.shape[1]
        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized sparse dynamic module from a float module.

        We only care about the conversion at this stage; no observers are needed just yet.
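
        A rough sketch of the expected call pattern (illustrative only; the layer
        size and qconfig choice are assumptions, and the block sizes are taken from
        the ambient ``LinearBlockSparsePattern``)::

            >>> # xdoctest: +SKIP
            >>> float_mod = torch.nn.Linear(8, 4)
            >>> float_mod.qconfig = torch.ao.quantization.default_dynamic_qconfig
            >>> qmod = Linear.from_float(float_mod)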
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            " nnq."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We run into circular import issues if we import qconfig at the beginning
            # of this file: https://github.com/pytorch/pytorch/pull/24231. The current
            # workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_observer`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
        weight = mod.weight
        if getattr(mod.qconfig, "mask", False):
            weight = mod.qconfig.mask * mod.weight

        weight_observer(weight)
        dtype = weight_observer.dtype
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        w_sc, w_zp = weight_observer.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, "Weight zero point must map to 0"
        qweight = _quantize_weight(weight.float(), weight_observer)

        row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
        qlinear = cls(
            mod.in_features,
            mod.out_features,
            row_block_size,
            col_block_size,
            dtype=dtype,
        )
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
        return qlinear