from typing import Tuple

import numpy as np

import torch
from torch.nn import functional as F


ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1


class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive rounding loss functions described in https://arxiv.org/pdf/2004.10568.pdf.
    The rounding regularization is eq. (24); the reconstruction loss is eq. (25)
    without the regularization term.
    """

    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param

    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Apply the rounding regularization to the continuous rounding variable V.
        Core logic adapted from the official AdaRound implementation.
        """
        assert (
            curr_iter < self.max_iter
        ), "Current iteration must be strictly less than max iteration"

        # During the warm-start phase, only the reconstruction loss is optimized.
        if curr_iter < self.warm_start * self.max_iter:
            return torch.tensor(0.0, device=V.device)

        start_beta, end_beta = self.beta_range
        warm_start_end_iter = self.warm_start * self.max_iter

        # Relative progress of the current iteration within the annealing window.
        rel_iter = (curr_iter - warm_start_end_iter) / (
            self.max_iter - warm_start_end_iter
        )
        # Cosine-anneal beta from start_beta down to end_beta.
        beta = end_beta + 0.5 * (start_beta - end_beta) * (
            1 + np.cos(rel_iter * np.pi)
        )

        # Rectified sigmoid for soft quantization, as formulated in eq. (23)
        # of https://arxiv.org/pdf/2004.10568.pdf
        h_alpha = torch.clamp(
            torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
            min=0,
            max=1,
        )

        # The regularization term pushes h(V) toward a binary solution
        # (either 0 or 1) by the end of optimization.
        inner_term = (2 * h_alpha - 1).abs().pow(beta)
        regularization_term = (1 - inner_term).sum()
        return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction loss (mean squared error) between the
        soft-quantized output and the original output.
        """
        return F.mse_loss(soft_quantized_output, original_output)

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the asymmetric reconstruction formulation of eq. (25), returning
        the rounding regularization term and the reconstruction term separately.
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term
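
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the paper's reference code).
# It shows one way to optimize V against a single linear layer's output.
# `_soft_quantize`, the fixed scale `s`, and all hyperparameters below are
# hypothetical assumptions, not part of the AdaptiveRoundingLoss API.
# ---------------------------------------------------------------------------
def _soft_quantize(weight: torch.Tensor, V: torch.Tensor, s: float) -> torch.Tensor:
    """Soft-quantize `weight` with the rectified-sigmoid relaxation h(V), eq. (23)."""
    h_alpha = torch.clamp(
        torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
        min=0,
        max=1,
    )
    return s * (torch.floor(weight / s) + h_alpha)


if __name__ == "__main__":
    torch.manual_seed(0)
    weight = torch.randn(32, 32)
    inputs = torch.randn(8, 32)
    original_output = inputs @ weight.t()

    s = 0.05  # toy quantization scale, assumed fixed here
    V = torch.zeros_like(weight, requires_grad=True)
    loss_fn = AdaptiveRoundingLoss(max_iter=1000)
    optimizer = torch.optim.Adam([V], lr=1e-2)

    for curr_iter in range(loss_fn.max_iter):
        soft_output = inputs @ _soft_quantize(weight, V, s).t()
        # The two terms are returned separately; the caller combines them.
        reg, recon = loss_fn(soft_output, original_output, V, curr_iter)
        loss = recon + reg
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()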