# torch/ao/nn/quantized/reference/modules/linear.py
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ReferenceQuantizedModule


__all__ = ["Linear"]


class Linear(nn.Linear, ReferenceQuantizedModule):
    """A reference quantized linear module that fits into the FX
    Graph Mode Quantization workflow.

    Activations are floating point Tensors. The module also stores the
    weight as a floating point Tensor, but in ``forward`` the weight is
    quantized and dequantized before running the floating point functional
    linear operator.
    """

    _IS_REFERENCE = True

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias_: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(in_features, out_features, bias_, device, dtype)
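        # _init_weight_qparams (defined in ReferenceQuantizedModule) stores
        # the weight quantization parameters (qscheme, dtype, scale,
        # zero_point, and axis for per-channel schemes) on the module so
        # that forward() can quantize and dequantize the weight.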
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        In this reference module we have:
        w(float) -- quant -- dequant \
        x(float) ------------------- F.linear ---

        In the full model, we will see
        w(float) -- quant -- *dequant \
        x -- quant -- *dequant -- *F.linear -- *quant -- dequant
        and the backend should be able to fuse the ops marked with `*` into
        a quantized linear op.
        """
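        # get_weight() (provided by ReferenceQuantizedModule) runs the stored
        # floating point weight through a quantize/dequantize round trip with
        # the module's weight qparams, so the returned Tensor is still float
        # but carries the quantization error.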
        weight_quant_dequant = self.get_weight()
        result = F.linear(x, weight_quant_dequant, self.bias)
        return result

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
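        """Create a reference quantized linear module from a float
        ``nn.Linear`` module, copying over its weight (and bias, if present)
        and attaching the given ``weight_qparams``.
        """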
        qref_linear = Linear(
            float_linear.in_features,
            float_linear.out_features,
            float_linear.bias is not None,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
            weight_qparams=weight_qparams,
        )
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
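
# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The
# weight_qparams keys shown ("qscheme", "dtype", "scale", "zero_point") follow
# the convention expected by ReferenceQuantizedModule; per-channel schemes
# also take an "axis" entry.
#
#     float_linear = nn.Linear(16, 8)
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.qint8,
#         "scale": 0.1,
#         "zero_point": 0,
#     }
#     ref_linear = Linear.from_float(float_linear, weight_qparams)
#     y = ref_linear(torch.randn(2, 16))  # float in, float out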