# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.dynamic as nnqd


__all__ = ["LinearReLU"]


class LinearReLU(nnqd.Linear):
    r"""
    A LinearReLU module fused from Linear and ReLU modules that can be used
    for dynamic quantization.
    Supports both FP16 and INT8 quantization.

    We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.dynamic.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = torch.ao.nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._packed_params.dtype == torch.qint8:
            # TODO check if we should set reduce_range = True by default here
            Y = torch.ops.quantized.linear_relu_dynamic(
                x, self._packed_params._packed_params, reduce_range=True
            )
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_relu_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
        return Y.to(x.dtype)

    def _get_name(self):
        return "DynamicQuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qlinear_relu):
        return super().from_reference(ref_qlinear_relu[0])
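

# A minimal end-to-end sketch of how this module is typically produced in the
# eager-mode dynamic quantization flow. It assumes the default dynamic
# quantization mappings route ``nni.LinearReLU`` to this class (check
# ``torch.ao.quantization.quantization_mappings`` for your version); the
# submodule names "0" and "1" are simply those ``nn.Sequential`` assigns.
if __name__ == "__main__":
    from torch import nn
    from torch.ao.quantization import fuse_modules, quantize_dynamic

    float_model = nn.Sequential(nn.Linear(20, 30), nn.ReLU()).eval()
    # Fuse Linear + ReLU into a single nni.LinearReLU module.
    fused = fuse_modules(float_model, [["0", "1"]])
    # Dynamically quantize the fused module; weights are quantized to int8,
    # activations are quantized on the fly inside the fused kernel.
    dmodel = quantize_dynamic(fused, {nni.LinearReLU}, dtype=torch.qint8)
    out = dmodel(torch.randn(128, 20))  # float in, float out
    assert isinstance(dmodel[0], LinearReLU)
    assert (out >= 0).all()  # ReLU is applied inside the fused op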
61