# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight


__all__ = [
    "Linear",
]


class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    # The version used in this class is different from that of the parent class nnq.Linear.
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *Everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the serialization and
        # deserialization methods below.
        self.version = 4

    def forward(self, x):
        # Note that we can handle the self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params
                )
            else:
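                # Since version 4, activations are quantized with
                # reduce_range=True, i.e. with one less bit of range, to
                # avoid potential overflow in the int8 accumulation on
                # some x86 (fbgemm) kernels.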
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True
                )
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
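        # Match the output dtype to the input dtype (the dynamic ops
        # return float tensors).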
        return Y.to(x.dtype)

    def _get_name(self):
        return "DynamicQuantizedLinear"

    def extra_repr(self):
        extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}"
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += f", qscheme={self.weight().qscheme()}"
        return extra_repr_str

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        self.version = version
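        # Note: the parent is deliberately called with strict=False,
        # regardless of the `strict` argument passed in.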
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
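
        A minimal usage sketch (assuming the stock
        ``torch.ao.quantization.default_dynamic_qconfig``)::

            >>> # xdoctest: +SKIP
            >>> float_mod = torch.nn.Linear(20, 30)
            >>> float_mod.qconfig = torch.ao.quantization.default_dynamic_qconfig
            >>> dq_mod = Linear.from_float(float_mod)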
        """
        float_modules = [
            torch.nn.Linear,
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
            torch.ao.nn.intrinsic.modules.fused.LinearReLU,
            torch.ao.nn.qat.dynamic.Linear,
        ]

        assert (
            type(mod) in float_modules
        ), "nn.quantized.dynamic.Linear.from_float only works for one of " + str(
            [float_mod.__name__ for float_mod in float_modules]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
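        # LinearReLU is a fused (Linear, ReLU) sequence; extract the
        # underlying float Linear before quantizing.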
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We would have circular import issues if we imported the qconfig
            # at the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current
            # workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], (
            "The only supported dtypes for "
            f"dynamic quantized linear are qint8 and float16, got: {dtype}"
        )
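        # Run the observer over the float weight so it records the
        # statistics used to compute the quantization parameters.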
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError(
                "Unsupported dtype specified for dynamic quantized Linear!"
            )
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear

    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference
        quantized module.

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                torch.ao.quantization functions or provided by the user
        """
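        # Build an empty dynamic quantized Linear, then copy in the
        # reference module's quantized weight and float bias.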
        qlinear = cls(
            ref_qlinear.in_features,
            ref_qlinear.out_features,
            dtype=ref_qlinear.weight_dtype,
        )
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear