import torch
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig


"""
Default symmetric fake_quant for activations.
"""
default_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
)

"""
Default symmetric fake_quant for weights.
"""
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
)

# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=default_weight_symmetric_fake_quant,
)

# uniform activation, APoT weight, b=8 k=2
apot_weight_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8),
)

# APoT activation and weight, b=8 k=2
apot_qconfig_8bit = QConfig(
    activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8),
)

# uniform activation and weight, b=4 k=2
uniform_qconfig_4bit = QConfig(
    activation=default_symmetric_fake_quant.with_args(quant_min=0, quant_max=15),
    weight=default_weight_symmetric_fake_quant.with_args(quant_min=0, quant_max=15),
)

# uniform activation, APoT weight, b=4 k=2
apot_weight_qconfig_4bit = QConfig(
    activation=default_symmetric_fake_quant.with_args(quant_min=0, quant_max=15),
    weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8),
)

# APoT activation and weight, b=4 k=2
apot_qconfig_4bit = QConfig(
    activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
    weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8),
)
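
# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hedged example of attaching one of the qconfigs
# above in the eager-mode QAT flow. The toy two-layer model below is
# illustrative only (an assumption, not part of this module); whether the
# experimental APoT fake-quants compose with `prepare_qat` may depend on the
# PyTorch build.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.ao.quantization import prepare_qat

    # Hypothetical toy model; any eager-mode float module works the same way.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).train()

    # Attach a qconfig defined above; swap in apot_weight_qconfig_8bit or
    # apot_qconfig_8bit to try the APoT variants.
    model.qconfig = uniform_qconfig_8bit

    # prepare_qat swaps in QAT modules and inserts fake-quant observers.
    prepared = prepare_qat(model, inplace=False)

    # One forward pass updates the observers' running min/max statistics.
    _ = prepared(torch.randn(2, 16))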