import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
class PerChannelParamObserver(UniformQuantizationObserverBase):
    """
    Minimize the quantization loss caused by outliers via a per-channel linear
    search over clipping thresholds. More details can be found at
    https://arxiv.org/pdf/2209.13325
    """

    def __init__(
        self,
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,  # noqa: B008
        is_dynamic=False,
        **kwargs,
    ) -> None:
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.use_mse = use_mse
        self.steps = steps
        self.calibrated = False

    def to_ch_axis(self, x):
        # Move the channel axis to dim 0 and flatten the remaining dims so
        # per-channel statistics can be computed along dim 1.
        axis_order = list(range(len(x.size())))
        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
        return torch.flatten(x.permute(axis_order), start_dim=1)

    def mse(self, pred, expect):
        loss = (pred - expect).abs().pow(2)
        return self.to_ch_axis(loss).mean(1)

    def cosine(self, pred, expect):
        target = torch.ones(pred.shape[self.ch_axis])
        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

    def loss_fn(self, x, new_min, new_max):
        # Fake-quantize with the candidate clipping range and measure the
        # reconstruction error against the original tensor.
        scale, offset = self._calculate_qparams(new_min, new_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x,
            scale.data,
            offset.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )
        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

    def line_search(self, x):
        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
        x_range = torch.max(x_min.abs(), x_max)
        optimal_loss = torch.zeros_like(x_min) + 1e9

        # check which clipping range produces the smallest loss
        for i in range(1, self.steps + 1):
            thres = x_range / self.steps * i
            current_loss = self.loss_fn(x, -thres, thres)
            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return x_min, x_max

    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        # return the fake-quantized result so that outliers are saturated
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        return torch.fake_quantize_per_channel_affine(
            x_orig,
            scale.data,
            zero_point.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
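

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of calibrating the observer on a weight tensor and reading
# back the searched quantization parameters. The `example_weight` tensor and
# its (out_channels, in_channels) shape below are hypothetical assumptions.
if __name__ == "__main__":
    example_weight = torch.randn(8, 16)  # hypothetical weight, 8 output channels

    observer = PerChannelParamObserver(
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
    )

    # The first forward call runs the per-channel line search once and returns
    # the fake-quantized tensor with outliers saturated to the searched range.
    saturated = observer(example_weight)

    # Per-channel scale/zero_point derived from the searched min_val/max_val.
    scale, zero_point = observer.calculate_qparams()
    print(scale.shape, zero_point.shape)  # one entry per output channel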