# mypy: allow-untyped-defs
from typing import List

import torch
from torch import Tensor
from torch._ops import ops


__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]


class FloatFunctional(torch.nn.Module):
    r"""State collector class for float operations.

    The instance of this class can be used instead of the ``torch.`` prefix for
    some operations. See example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> f_add = FloatFunctional()
        >>> a = torch.tensor(3.0)
        >>> b = torch.tensor(4.0)
        >>> f_add.add(a, b)  # Equivalent to ``torch.add(a, b)``

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
    """

    def __init__(self) -> None:
        super().__init__()
        # Applied to the result of every observed operation below; a no-op
        # ``Identity`` unless something replaces it.
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        raise RuntimeError(
            "FloatFunctional is not intended to use the "
            + "'forward'. Please use the underlying operation"
        )

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, float)``"""
        r = torch.add(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
        r = torch.mul(x, y)
        r = self.activation_post_process(r)
        return r

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
        r = torch.mul(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.cat``"""
        r = torch.cat(x, dim=dim)
        r = self.activation_post_process(r)
        return r

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``relu(torch.add(x,y))``"""
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        r = self.activation_post_process(r)
        return r

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
        r = torch.matmul(x, y)
        r = self.activation_post_process(r)
        return r


class FXFloatFunctional(torch.nn.Module):
    r"""module to replace FloatFunctional module before FX graph mode quantization,
    since activation_post_process will be inserted in top level module directly

    Unlike ``FloatFunctional``, no ``activation_post_process`` is applied to
    the results here; every method is a plain call to the ``torch`` op.

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
    """

    def forward(self, x):
        raise RuntimeError(
            "FXFloatFunctional is not intended to use the "
            "'forward'. Please use the underlying operation"
        )

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        r = torch.add(x, y)
        return r

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, float)``"""
        r = torch.add(x, y)
        return r

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
        r = torch.mul(x, y)
        return r

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
        r = torch.mul(x, y)
        return r

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.cat``"""
        r = torch.cat(x, dim=dim)
        return r

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``relu(torch.add(x,y))``"""
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        return r

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
        r = torch.matmul(x, y)
        return r


class QFunctional(torch.nn.Module):
    r"""Wrapper class for quantized operations.

    The instance of this class can be used instead of the
    ``torch.ops.quantized`` prefix. See example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> q_add = QFunctional()
        >>> # xdoctest: +SKIP
        >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
        >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
        >>> q_add.add(a, b)  # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
    """

    def __init__(self) -> None:
        super().__init__()
        # Output quantization parameters passed to the quantized ops that
        # accept them (add, mul, cat, add_relu, matmul).
        self.scale = 1.0
        self.zero_point = 0
        self.activation_post_process = torch.nn.Identity()

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save base-module state plus ``scale``/``zero_point`` as tensors."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = torch.tensor(self.scale)
        destination[prefix + "zero_point"] = torch.tensor(self.zero_point)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Restore ``scale``/``zero_point`` and delegate the rest to the base loader.

        Both keys are popped from ``state_dict`` (a ``KeyError`` is raised if
        either is absent) and converted back to Python ``float``/``int``. The
        base loader is always called with ``strict=False``.
        """
        self.scale = float(state_dict.pop(prefix + "scale"))
        self.zero_point = int(state_dict.pop(prefix + "zero_point"))
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _get_name(self):
        return "QFunctional"

    def extra_repr(self):
        return f"scale={self.scale}, zero_point={self.zero_point}"

    def forward(self, x):
        raise RuntimeError(
            "QFunctional is not intended to use the "
            "'forward'. Please use the underlying operation"
        )

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add``"""
        r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
        r = ops.quantized.add_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
        r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
        r = ops.quantized.mul_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.cat``"""
        r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
        r = self.activation_post_process(r)
        return r

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
        r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
        r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        # NOTE(review): unlike add/mul/cat/add_relu above (and unlike
        # FloatFunctional.matmul), activation_post_process is not applied
        # here — confirm this asymmetry is intended.
        return r

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a ``QFunctional`` from an observed ``FloatFunctional``.

        Args:
            mod: a ``FloatFunctional`` (exact type, not a subclass) whose
                ``activation_post_process`` provides ``calculate_qparams()``.
            use_precomputed_fake_quant: accepted for interface compatibility;
                not used by this implementation.

        Returns:
            A new ``QFunctional`` whose ``scale``/``zero_point`` come from
            ``mod.activation_post_process.calculate_qparams()``.
        """
        assert (
            type(mod) is FloatFunctional
        ), "QFunctional.from_float expects an instance of FloatFunctional"
        scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
        new_mod = QFunctional()
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod