# mypy: allow-untyped-defs
from typing import Optional

import torch
import torch.ao.nn.intrinsic as nni
from torch.ao.nn.quantized.modules.utils import (
    _hide_packed_params_repr,
    _quantize_weight,
)
from torch.ao.nn.sparse.quantized import linear
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern


__all__ = ["Linear"]


class Linear(torch.nn.Module):
    r"""
    A dynamically quantized sparse linear module with float tensors as inputs and outputs.
    """
    _version = 1
    _op_type = "sparse_dynamic"
    _FLOAT_MODULE = torch.nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        row_block_size,
        col_block_size,
        bias=True,
        dtype=torch.qint8,
    ):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError(
                "Only QINT8 is supported for Sparse Quantized Linear Dynamic"
            )

        self.in_features = in_features
        self.out_features = out_features

        if bias:
            bias = torch.zeros(self.out_features, dtype=torch.float)
        else:
            bias = None

        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
        )
        self._packed_params = linear.LinearPackedParams(
            row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
        )
        self._packed_params.set_weight_bias(
            qweight, bias, row_block_size, col_block_size
        )

    def _get_name(self):
        return "SparseQuantizedDynamicLinear"

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}"

    def __repr__(self):
        return _hide_packed_params_repr(self, linear.LinearPackedParams)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "op_type"] = self._op_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        op_type = state_dict[prefix + "op_type"]
        assert (
            op_type == self._op_type
        ), f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
        state_dict.pop(prefix + "op_type")

        version = local_metadata.get("version", None)
        assert version <= self._version

        # TODO: Verify whether this branch is still needed. In the old
        # quantization flow it appears to have been used to load older models.
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias")
        state_dict.update(
            {
                prefix + "_packed_params.weight": weight,
                prefix + "_packed_params.bias": bias,
            }
        )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _weight_bias(self):
        return self._packed_params._weight_bias()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def set_weight_bias(
        self,
        w: torch.Tensor,
        b: Optional[torch.Tensor],
        row_block_size: Optional[int],
        col_block_size: Optional[int],
    ) -> None:
        assert row_block_size is not None and col_block_size is not None
        self.out_features = w.shape[0]
        self.in_features = w.shape[1]
        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized sparse dynamic module from a float module.

        Only the conversion matters at this stage; observers are not needed yet.
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            f"nnq.{cls.__name__}.from_float only works for "
            f"{cls._FLOAT_MODULE.__name__}"
        )
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We would hit circular import issues if we imported the qconfig at the
            # top of this file: https://github.com/pytorch/pytorch/pull/24231.
            # The current workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_observer`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
        weight = mod.weight
        if getattr(mod.qconfig, "mask", False):
            weight = mod.qconfig.mask * mod.weight

        weight_observer(weight)
        dtype = weight_observer.dtype
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        w_sc, w_zp = weight_observer.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, "Weight zero point must map to 0"
        qweight = _quantize_weight(weight.float(), weight_observer)

        row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
        qlinear = cls(
            mod.in_features,
            mod.out_features,
            row_block_size,
            col_block_size,
            dtype=dtype,
        )
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
        return qlinear
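

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a
# float ``torch.nn.Linear`` with a qconfig can be converted through
# ``Linear.from_float`` under a supported block-sparse pattern. Assumptions:
# a PyTorch build with the QNNPACK quantized backend, which provides the
# ``torch.ops.sparse`` kernels used above; the shapes below are arbitrary
# examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.ao.quantization.qconfig import default_dynamic_qconfig

    # The sparse quantized linear kernels are implemented via QNNPACK.
    torch.backends.quantized.engine = "qnnpack"

    float_linear = torch.nn.Linear(16, 8)
    float_linear.qconfig = default_dynamic_qconfig

    # (row_block_size, col_block_size) = (1, 4) is one of the block-sparse
    # patterns accepted by LinearBlockSparsePattern; from_float reads the
    # active pattern via LinearBlockSparsePattern.block_size().
    with LinearBlockSparsePattern(1, 4):
        sparse_qlinear = Linear.from_float(float_linear)

    out = sparse_qlinear(torch.randn(2, 16))
    print(sparse_qlinear, out.shape)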