xref: /aosp_15_r20/external/executorch/backends/qualcomm/quantizer/observers/per_channel_param_observer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1import torch
2from torch.ao.quantization.observer import UniformQuantizationObserverBase
3
4
5# TODO move to torch/ao/quantization/observer.py.
6class PerChannelParamObserver(UniformQuantizationObserverBase):
7    """
8    Minimize quantization loss caused by outlier via linear search. More details can be found at https://arxiv.org/pdf/2209.13325
9    """
10
11    def __init__(
12        self,
13        ch_axis=0,
14        use_mse=True,
15        steps=100,
16        dtype=torch.int8,
17        qscheme=torch.per_channel_symmetric,
18        reduce_range=False,
19        quant_min=None,
20        quant_max=None,
21        factory_kwargs=None,
22        eps=torch.finfo(torch.float32).eps,  # noqa: B008
23        is_dynamic=False,
24        **kwargs,
25    ) -> None:
26        super().__init__(
27            dtype=dtype,
28            qscheme=qscheme,
29            reduce_range=reduce_range,
30            quant_min=quant_min,
31            quant_max=quant_max,
32            factory_kwargs=factory_kwargs,
33            eps=eps,
34            is_dynamic=is_dynamic,
35            **kwargs,
36        )
37
38        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
39        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
40        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
41        self.ch_axis = ch_axis
42        self.use_mse = use_mse
43        self.steps = steps
44        self.calibrated = False
45
46    def to_ch_axis(self, x):
47        axis_order = list(range(len(x.size())))
48        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
49        return torch.flatten(x.permute(axis_order), start_dim=1)
50
51    def mse(self, pred, expect):
52        loss = (pred - expect).abs().pow(2)
53        return self.to_ch_axis(loss).mean(1)
54
55    def cosine(self, pred, expect):
56        target = torch.ones(pred.shape[self.ch_axis])
57        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
58        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
59        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)
60
61    def loss_fn(self, x, new_min, new_max):
62        scale, offset = self._calculate_qparams(new_min, new_max)
63        x_q = torch.fake_quantize_per_channel_affine(
64            x,
65            scale.data,
66            offset.data.int(),
67            self.ch_axis,
68            self.quant_min,
69            self.quant_max,
70        )
71        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)
72
73    def line_search(self, x):
74        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
75        x_range = torch.max(x_min.abs(), x_max)
76        optimal_loss = torch.zeros_like(x_min) + 1e9
77
78        # check which clip range could produce smallest loss
79        for i in range(1, self.steps + 1):
80            thres = x_range / self.steps * i
81            current_loss = self.loss_fn(x, -thres, thres)
82            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
83            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
84            optimal_loss = torch.min(current_loss, optimal_loss)
85
86        return x_min, x_max
87
88    def forward(self, x_orig):
89        # since params are static, one calibration is enough
90        if not self.calibrated:
91            x = x_orig.detach().to(self.min_val.dtype)
92            self.min_val, self.max_val = self.line_search(x)
93            self.calibrated = True
94
95        # return fake-quant result for saturating outliers
96        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
97        return torch.fake_quantize_per_channel_affine(
98            x_orig,
99            scale.data,
100            zero_point.data.int(),
101            self.ch_axis,
102            self.quant_min,
103            self.quant_max,
104        )
105
106    @torch.jit.export
107    def calculate_qparams(self):
108        return self._calculate_qparams(self.min_val, self.max_val)
109