# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight


__all__ = [
    "Linear",
]


class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module, of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module, of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    # version used in this class is different from the parent class nnq.Linear
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *Everything* is simply a Python attribute.
        # Serialization logic is handled explicitly in the serialization and
        # deserialization methods below.
        self.version = 4

    def forward(self, x):
        # Note that we can handle the self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params
                )
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True
                )
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
        return Y.to(x.dtype)

    def _get_name(self):
        return "DynamicQuantizedLinear"

    def extra_repr(self):
        extra_repr_str = (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, dtype={self._packed_params.dtype}"
        )
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += f", qscheme={self.weight().qscheme()}"
        return extra_repr_str

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        self.version = version
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a dynamic quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        float_modules = [
            torch.nn.Linear,
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
            torch.ao.nn.intrinsic.modules.fused.LinearReLU,
            torch.ao.nn.qat.dynamic.Linear,
        ]

        assert type(mod) in float_modules, (
            "nn.quantized.dynamic.Linear.from_float only works for one of "
            + str([float_mod.__name__ for float_mod in float_modules])
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have circular import issues if we import qconfig at the beginning of
            # this file: https://github.com/pytorch/pytorch/pull/24231. The current
            # workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], (
            "The only supported dtypes for "
            f"dynamic quantized linear are qint8 and float16, got: {dtype}"
        )
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError(
                "Unsupported dtype specified for dynamic quantized Linear!"
            )
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear

    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference
        quantized module.

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                torch.ao.quantization functions or provided by the user
        """
        qlinear = cls(
            ref_qlinear.in_features,
            ref_qlinear.out_features,
            dtype=ref_qlinear.weight_dtype,
        )
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear
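

# A minimal end-to-end usage sketch (illustrative only; not part of the module
# above). DynamicQuantizedLinear instances are normally produced by the
# eager-mode entry point `torch.ao.quantization.quantize_dynamic`, which calls
# `Linear.from_float` on every eligible `torch.nn.Linear`. The toy model below
# is a hypothetical example:
#
#     >>> # xdoctest: +SKIP
#     >>> import torch
#     >>> import torch.ao.quantization
#     >>> float_model = torch.nn.Sequential(torch.nn.Linear(20, 30))
#     >>> qmodel = torch.ao.quantization.quantize_dynamic(
#     ...     float_model, {torch.nn.Linear}, dtype=torch.qint8
#     ... )
#     >>> type(qmodel[0])  # the float nn.Linear has been swapped out
#     <class 'torch.ao.nn.quantized.dynamic.modules.linear.Linear'>
#     >>> output = qmodel(torch.randn(128, 20))  # inputs/outputs stay float32
#     >>> output.size()
#     torch.Size([128, 30])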