# mypy: allow-untyped-defs
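"""Fake quantization for Additive Powers-of-Two (APoT) quantization.

Defines ``APoTFakeQuantize``, a ``FakeQuantizeBase`` module that uses an
``APoTObserver`` to derive APoT quantization parameters and applies
``fake_quantize_function`` to simulate quantization in the forward pass.
"""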
import torch
from torch import Tensor
from torch.ao.quantization.experimental.fake_quantize_function import (
    fake_quantize_function,
)
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.fake_quantize import FakeQuantizeBase


class APoTFakeQuantize(FakeQuantizeBase):
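    """Simulates Additive Powers-of-Two (APoT) quantization of a tensor.

    The attached observer (``APoTObserver`` by default) records statistics of
    the tensors passed through :meth:`forward` and produces the APoT
    quantization parameters ``alpha``, ``gamma``, ``quantization_levels``, and
    ``level_indices``, which ``fake_quantize_function`` then uses to quantize
    and dequantize the input while keeping it in floating point.
    """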
    alpha: Tensor
    gamma: Tensor
    quantization_levels: Tensor
    level_indices: Tensor

    def __init__(self, observer=APoTObserver, **observer_kwargs):
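        """Create the fake quantize module with an attached APoT observer.

        Args:
            observer: observer class (not an instance) used to collect
                statistics and compute the APoT qparams; defaults to
                ``APoTObserver``.
            observer_kwargs: keyword arguments forwarded to the observer
                constructor. The module's ``dtype`` mirrors the observer's.
        """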
        super().__init__()
        self.activation_post_process = observer(**observer_kwargs)
        self.dtype = self.activation_post_process.dtype

    def calculate_qparams(self, signed=False):  # type: ignore[override]
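        """Delegate to the observer's ``calculate_qparams``.

        Returns:
            A tuple ``(alpha, gamma, quantization_levels, level_indices)`` of
            APoT quantization parameters.
        """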
        return self.activation_post_process.calculate_qparams(signed=signed)

    def forward(self, X: torch.Tensor):  # type: ignore[override]
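        """Observe ``X`` and/or fake quantize it based on the enable flags.

        When observation is enabled, ``X`` is run through the observer and the
        cached APoT qparams are refreshed. When fake quantization is enabled,
        ``fake_quantize_function`` quantizes and dequantizes ``X`` with those
        qparams; otherwise ``X`` is returned unchanged.
        """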
        if self.observer_enabled[0] == 1:
            # Run the observer on X and refresh the cached APoT qparams.
            self.activation_post_process.forward(X)
            result = self.activation_post_process.calculate_qparams(signed=False)
            self.alpha = result[0]
            self.gamma = result[1]
            self.quantization_levels = result[2]
            self.level_indices = result[3]

        if self.fake_quant_enabled[0] == 1:
            assert (
                self.alpha is not None
                and self.gamma is not None
                and self.quantization_levels is not None
                and self.level_indices is not None
            ), "Must set qparams for fake quant"

            # Quantize and dequantize X with the cached APoT qparams via the
            # custom autograd function, simulating quantization error in float.
            X = fake_quantize_function.apply(
                X, self.alpha, self.gamma, self.quantization_levels, self.level_indices
            )

        return X
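
# Minimal usage sketch, assuming ``APoTObserver`` accepts the APoT bitwidth
# parameters ``b`` and ``k`` as constructor keyword arguments; adjust to the
# observer's actual signature if it differs:
#
#     fq = APoTFakeQuantize(b=4, k=2)
#     fq.enable_observer()
#     fq.enable_fake_quant()
#     x = torch.randn(16, 16)
#     y = fq(x)  # statistics observed, then x fake quantized to APoT levels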