# mypy: allow-untyped-defs
import abc
import collections
import itertools

import torch
from torch.nn.modules.module import _addindent


__all__ = [
    "WeightedQuantizedModule",
]


class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
    """Wrapper for quantized modules that can be lowered from reference modules."""

    @classmethod
    @abc.abstractmethod
    def from_reference(cls, ref_module, output_scale, output_zero_point):
        raise NotImplementedError


def _get_weight_observer(observer):
    # FakeQuantize observer
    if hasattr(observer, "activation_post_process"):
        observer = observer.activation_post_process
    # UniformQuantizationObserverBase observer
    return observer


def _needs_weight_clamping(observer, dtype):
    observer = _get_weight_observer(observer)
    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
        info = torch.iinfo(dtype)
        return observer.quant_min > info.min or observer.quant_max < info.max
    return False


def _clamp_weights(qweight, observer, scale, zp):
    if not _needs_weight_clamping(observer, qweight.dtype):
        return qweight

    observer = _get_weight_observer(observer)
    min_, max_ = observer.quant_min, observer.quant_max

    # Doing this because we can't use torch.ops.quantized.clamp() with a
    # per_channel qscheme yet.
    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
    qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
    qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)

    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch._make_per_tensor_quantized_tensor(
            qw_int, scale.item(), zp.item()
        )
    elif observer.qscheme in [
        torch.per_channel_symmetric,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        qweight = torch._make_per_channel_quantized_tensor(
            qw_int, scale, zp, axis=observer.ch_axis
        )
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight


def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double),
            wt_zp.to(torch.int64),
            wt_axis,
            torch.qint8,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.float),
            wt_zp.to(torch.float),
            observer.ch_axis,
            observer.dtype,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
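

# A minimal usage sketch of _quantize_weight (illustration only, not part of
# this module's API). The weight observer must have been run on the weight so
# that calculate_qparams() is meaningful; PerChannelMinMaxObserver is assumed
# here as the weight observer.
#
#   from torch.ao.quantization.observer import PerChannelMinMaxObserver
#
#   weight = torch.randn(8, 4)
#   obs = PerChannelMinMaxObserver(
#       ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric
#   )
#   obs(weight)  # record per-channel min/max so calculate_qparams() is valid
#   qweight = _quantize_weight(weight, obs)  # per-channel qint8 tensor

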
def _ntuple_from_first(n):
    """Converts the argument to a tuple of size n
    with the first element repeated."""

    def parse(x):
        while isinstance(x, collections.abc.Sequence):
            if len(x) == n:
                break
            x = x[0]
        return tuple(itertools.repeat(x, n))

    return parse


def _hide_packed_params_repr(self, params):
    # We don't want to show `PackedParams` children, hence the custom
    # `__repr__`. This is the same as nn.Module.__repr__, except for the
    # check against the `params` module class.
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split("\n")
    child_lines = []
    for key, module in self._modules.items():
        if isinstance(module, params):
            continue
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append("(" + key + "): " + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + "("
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += "\n  " + "\n  ".join(lines) + "\n"

    main_str += ")"
    return main_str


_pair_from_first = _ntuple_from_first(2)
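
# Illustrative behavior of _pair_from_first (a sketch, values derived from the
# definition above): scalars and nested singleton sequences collapse to the
# innermost element, which is then repeated to form a pair.
#
#   _pair_from_first(3)      -> (3, 3)
#   _pair_from_first([3])    -> (3, 3)
#   _pair_from_first([[3]])  -> (3, 3)
#
# Quantized Conv1d-style modules use this so arguments like kernel_size can be
# given as an int or a 1-sequence while still producing the 2-tuple that the
# underlying 2d kernels expect.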