# torch/ao/quantization/experimental/adaround_loss.py
from typing import Tuple

import numpy as np

import torch
from torch.nn import functional as F


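# Stretch parameters (zeta, gamma) of the rectified sigmoid h(V) from eq. [23]
# of the AdaRound paper; stretching the sigmoid beyond [0, 1] before clamping
# lets h(V) saturate at exactly 0 or 1.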
ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1


class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive rounding loss terms described in https://arxiv.org/pdf/2004.10568.pdf:
    the rounding regularization is eq. [24], and the reconstruction loss is
    eq. [25] without its regularization term.
    """

    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
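        # warm_start is the fraction of max_iter during which no rounding
        # regularization is applied; beta_range gives the (start, end) values
        # of the annealed exponent beta; reg_param scales the regularization
        # term relative to the reconstruction loss.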
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param

    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Apply rounding regularization to the input tensor V.
        Core logic is adapted from the official AdaRound implementation.
        """
        assert (
            curr_iter < self.max_iter
        ), "Current iteration must be strictly less than max_iter"
        if curr_iter < self.warm_start * self.max_iter:
            return torch.tensor(0.0)
        else:
            start_beta, end_beta = self.beta_range
            warm_start_end_iter = self.warm_start * self.max_iter

            # Fraction of the post-warm-start schedule completed so far
            rel_iter = (curr_iter - warm_start_end_iter) / (
                self.max_iter - warm_start_end_iter
            )
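            # Anneal beta from start_beta down to end_beta on a cosine schedule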
            beta = end_beta + 0.5 * (start_beta - end_beta) * (
                1 + np.cos(rel_iter * np.pi)
            )

            # Rectified sigmoid for soft quantization, as formulated in eq. [23]
            # of https://arxiv.org/pdf/2004.10568.pdf
            h_alpha = torch.clamp(
                torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
                min=0,
                max=1,
            )

            # Apply rounding regularization (eq. [24]): 1 - |2 * h_alpha - 1| ** beta
            # is minimized when every entry of h_alpha is exactly 0 or 1, so this
            # term drives the soft rounding toward a binary solution by the end
            # of optimization.
            inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
            regularization_term = torch.add(1, -inner_term).sum()
            return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction loss between the soft-quantized output
        and the original output.
        """
        return F.mse_loss(
            soft_quantized_output, original_output, reduction="none"
        ).mean()

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the asymmetric reconstruction formulation of eq. [25],
        returning the rounding regularization and reconstruction terms
        separately.
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term
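

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a plain linear layer whose weight is quantized with a symmetric
# per-tensor scale; the shapes, learning rate, and iteration count below are
# arbitrary demonstration choices, not values from the paper or this file.
if __name__ == "__main__":
    torch.manual_seed(0)

    weight = torch.randn(64, 128)
    scale = weight.abs().max() / 127.0
    X = torch.randn(32, 128)
    original_output = X @ weight.t()

    # V parameterizes the per-weight rounding decision via h_alpha(V).
    V = torch.nn.Parameter(torch.zeros_like(weight))
    optimizer = torch.optim.Adam([V], lr=1e-2)

    max_iter = 1000
    loss_fn = AdaptiveRoundingLoss(max_iter=max_iter)

    for curr_iter in range(max_iter):
        # Soft-quantized weight: floor of the scaled weight plus the learned
        # soft rounding h_alpha(V), mapped back to the original scale.
        h_alpha = torch.clamp(
            torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
            min=0,
            max=1,
        )
        soft_weight = (torch.floor(weight / scale) + h_alpha) * scale
        soft_output = X @ soft_weight.t()

        regularization, reconstruction = loss_fn(
            soft_output, original_output, V, curr_iter
        )
        loss = regularization + reconstruction
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()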
99