xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/stubs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3from torch import nn
4
5
class QuantStub(nn.Module):
    r"""Placeholder marking the point where a float tensor gets quantized.

    Before calibration this stub stands in for an observer on its input; its
    own ``forward`` simply hands the input back unchanged. During ``convert``
    it is replaced by ``nnq.Quantize``, which performs the real quantization.

    Args:
        qconfig: optional quantization configuration for the tensor. When it
            is not supplied (or is falsy), no ``qconfig`` attribute is set on
            this module and the configuration is taken from parent modules.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Leave the attribute unset when no config is given so that the
        # configuration can be resolved from the enclosing modules instead.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x):
        # Identity until ``convert`` swaps in the actual quantize module.
        return x
22
23
class DeQuantStub(nn.Module):
    r"""Placeholder marking the point where a quantized tensor gets dequantized.

    Before calibration it behaves as the identity function. During
    ``convert`` it is replaced by ``nnq.DeQuantize``.

    Args:
        qconfig: optional quantization configuration for the tensor. When it
            is not supplied (or is falsy), no ``qconfig`` attribute is set on
            this module and the configuration is taken from parent modules.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        has_config = bool(qconfig)
        # Only store an explicitly supplied config; otherwise defer to parents.
        if has_config:
            self.qconfig = qconfig

    def forward(self, x):
        # Pass-through; ``convert`` replaces this module with nnq.DeQuantize.
        return x
40
41
class QuantWrapper(nn.Module):
    r"""Wrap a module between a ``QuantStub`` and a ``DeQuantStub``.

    The ``quantization`` utility functions use this wrapper to insert quant
    and dequant modules around an existing module. Before ``convert``, the
    ``QuantStub`` only observes the input tensor; after ``convert`` it is
    swapped for ``nnq.Quantize``, which performs the actual quantization.
    ``DeQuantStub`` is treated the same way on the output side.

    Both stubs receive the wrapped module's ``qconfig`` attribute (or ``None``
    when it has none), and the wrapper adopts the wrapped module's
    ``training`` flag.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        inherited_qconfig = getattr(module, "qconfig", None)
        # Registration order (quant, dequant, module) determines the order in
        # which the children are registered on this module.
        children = (
            ("quant", QuantStub(inherited_qconfig)),
            ("dequant", DeQuantStub(inherited_qconfig)),
            ("module", module),
        )
        for child_name, child in children:
            self.add_module(child_name, child)
        # NOTE(review): nn.Module.train() recurses into children, so this also
        # resets the training flag of every submodule inside ``module`` to
        # ``module.training`` -- confirm that overriding a mixed train/eval
        # state in the wrapped module is intended.
        self.train(module.training)

    def forward(self, X):
        # quantize -> wrapped module -> dequantize
        return self.dequant(self.module(self.quant(X)))
69