# mypy: allow-untyped-defs
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ReferenceQuantizedModule


__all__ = ["Linear"]


class Linear(nn.Linear, ReferenceQuantizedModule):
    """A reference quantized linear module that fits into the FX
    Graph Mode Quantization workflow.

    The activation is a floating point Tensor, and we also store the
    floating point weight in the module, but in forward we quantize and
    dequantize the weight before running the floating point functional
    linear operator.
    """

    _IS_REFERENCE = True

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias_: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(in_features, out_features, bias_, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        In this module we have:
        w(float) -- quant - dequant \
        x(float) ------------------- F.linear ---

        In the full model, we will see:
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.linear --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a
        quantized linear.
        """
        weight_quant_dequant = self.get_weight()
        result = F.linear(x, weight_quant_dequant, self.bias)
        return result

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        qref_linear = Linear(
            float_linear.in_features,
            float_linear.out_features,
            float_linear.bias is not None,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
            weight_qparams=weight_qparams,
        )
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
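
# Conceptually, `get_weight()` (inherited from ReferenceQuantizedModule)
# returns a fake-quantized copy of the weight. For the per-tensor affine
# case it behaves roughly like the sketch below (an illustration of the
# numerics, not the actual implementation):
#
#     q = torch.quantize_per_tensor(self.weight, scale, zero_point, dtype)
#     weight_quant_dequant = q.dequantize()
#
# so forward() stays entirely in floating point while still exhibiting the
# rounding/clamping behavior of the quantized weight.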
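
# Usage sketch (illustrative only). The weight_qparams keys shown here
# ("qscheme", "dtype", "scale", "zero_point") are an assumption about the
# dict layout consumed by ReferenceQuantizedModule._init_weight_qparams;
# the concrete values are placeholders:
#
#     float_linear = nn.Linear(4, 8)
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.qint8,
#         "scale": 0.05,
#         "zero_point": 0,
#     }
#     ref_linear = Linear.from_float(float_linear, weight_qparams)
#     out = ref_linear(torch.randn(2, 4))  # float in, float out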