1# mypy: allow-untyped-defs 2import torch 3from torch import Tensor 4from torch.ao.quantization.experimental.fake_quantize_function import ( 5 fake_quantize_function, 6) 7from torch.ao.quantization.experimental.observer import APoTObserver 8from torch.ao.quantization.fake_quantize import FakeQuantizeBase 9 10 11class APoTFakeQuantize(FakeQuantizeBase): 12 alpha: Tensor 13 gamma: Tensor 14 quantization_levels: Tensor 15 level_indices: Tensor 16 17 def __init__(self, observer=APoTObserver, **observer_kwargs): 18 super().__init__() 19 self.activation_post_process = observer(**observer_kwargs) 20 self.dtype = self.activation_post_process.dtype 21 22 def calculate_qparams(self, signed=False): # type: ignore[override] 23 return self.activation_post_process.calculate_qparams(signed=signed) 24 25 def forward(self, X: torch.Tensor): # type: ignore[override] 26 if self.observer_enabled[0] == 1: 27 self.activation_post_process.forward(X) 28 result = self.activation_post_process.calculate_qparams(signed=False) 29 self.alpha = result[0] 30 self.gamma = result[1] 31 self.quantization_levels = result[2] 32 self.level_indices = result[3] 33 34 if self.fake_quant_enabled[0] == 1: 35 assert ( 36 self.alpha is not None 37 and self.gamma is not None 38 and self.quantization_levels is not None 39 and self.level_indices is not None 40 ), "Must set qparams for fake quant" 41 42 X = fake_quantize_function.apply( 43 X, self.alpha, self.gamma, self.quantization_levels, self.level_indices 44 ) 45 46 return X 47