# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.common_types import _size_1_t

from .utils import ReferenceQuantizedModule


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
]


class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
    """A reference version of the nn.quantized convolution modules.

    The parameters are deliberately not packed in this module: weight packing
    is an optimization for the quantized backends shipped with PyTorch
    (fbgemm/qnnpack). Keeping them unpacked is useful when the module is
    meant for other backends, such as Glow.
    """

    __annotations__ = {"bias": Optional[torch.Tensor]}
    _IS_REFERENCE = True

    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        """Create a reference conv of type ``cls`` from a float conv module.

        Args:
            cls: the reference conv class to instantiate (e.g. ``Conv2d`` from
                this module); passed explicitly by the subclasses' classmethods.
            float_conv: the float conv module whose hyperparameters, device,
                dtype, weight and bias are reused.
            weight_qparams: weight quantization parameters forwarded to the
                constructor of ``cls``.

        Returns:
            A ``cls`` instance whose ``weight`` (and ``bias``, if present) are
            new Parameters wrapping ``detach()``-ed views of ``float_conv``'s.
        """
        qref_conv = cls(
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,  # type: ignore[arg-type]
            float_conv.stride,  # type: ignore[arg-type]
            float_conv.padding,  # type: ignore[arg-type]
            float_conv.dilation,  # type: ignore[arg-type]
            float_conv.groups,
            float_conv.bias is not None,  # type: ignore[arg-type]
            float_conv.padding_mode,
            device=float_conv.weight.device,
            dtype=float_conv.weight.dtype,
            weight_qparams=weight_qparams,
        )
        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
        if float_conv.bias is not None:
            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
        return qref_conv


class Conv1d(_ConvNd, nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv1d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv1d ---

        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d
        """
        weight_quant_dequant = self.get_weight()
        # Delegate to nn.Conv1d._conv_forward so that a non-"zeros"
        # padding_mode pads the input exactly like the float module does;
        # for "zeros" this is a plain F.conv1d call.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)


class Conv2d(_ConvNd, nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv2d ---

        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d
        """
        weight_quant_dequant = self.get_weight()
        # Delegate to nn.Conv2d._conv_forward so that a non-"zeros"
        # padding_mode pads the input exactly like the float module does;
        # for "zeros" this is a plain F.conv2d call.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)


class Conv3d(_ConvNd, nn.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv3d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv3d ---

        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv3d
        """
        weight_quant_dequant = self.get_weight()
        # Delegate to nn.Conv3d._conv_forward so that a non-"zeros"
        # padding_mode pads the input exactly like the float module does;
        # for "zeros" this is a plain F.conv3d call.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)


class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
    """A reference version of the nn.quantized transposed-convolution modules.

    The parameters are deliberately not packed in this module: weight packing
    is an optimization for the quantized backends shipped with PyTorch
    (fbgemm/qnnpack). Keeping them unpacked is useful when the module is
    meant for other backends, such as Glow.
    """

    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        """Create a reference transposed conv of type ``cls`` from a float one.

        Args:
            cls: the reference class to instantiate (e.g. ``ConvTranspose2d``
                from this module); passed explicitly by the subclasses'
                classmethods.
            float_conv: the float transposed-conv module whose
                hyperparameters, device, dtype, weight and bias are reused.
            weight_qparams: weight quantization parameters forwarded to the
                constructor of ``cls``.

        Returns:
            A ``cls`` instance whose ``weight`` (and ``bias``, if present) are
            new Parameters wrapping ``detach()``-ed views of ``float_conv``'s.
        """
        # Positional order follows the ConvTranspose constructors below:
        # output_padding comes before groups, and dilation after bias.
        qref_conv = cls(
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,  # type: ignore[arg-type]
            float_conv.stride,  # type: ignore[arg-type]
            float_conv.padding,  # type: ignore[arg-type]
            float_conv.output_padding,  # type: ignore[arg-type]
            float_conv.groups,
            float_conv.bias is not None,  # type: ignore[arg-type]
            float_conv.dilation,  # type: ignore[arg-type]
            float_conv.padding_mode,
            device=float_conv.weight.device,
            dtype=float_conv.weight.dtype,
            weight_qparams=weight_qparams,
        )
        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
        if float_conv.bias is not None:
            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
        return qref_conv


class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose1d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv_transpose1d ---
        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv_transpose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose1d

        Args:
            x: the input tensor.
            output_size: optional requested output size; when given, the
                output padding is derived from it via ``_output_padding``
                instead of using ``self.output_padding``.
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        # Pass the actual input `x` (it is inspected when output_size is set)
        # and num_spatial_dims positionally, matching nn.ConvTranspose1d.forward.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose1d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
        return result

    def _get_name(self):
        return "QuantizedConvTranspose1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)


class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv_transpose2d ---
        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv_transpose2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose2d

        Args:
            x: the input tensor.
            output_size: optional requested output size; when given, the
                output padding is derived from it via ``_output_padding``
                instead of using ``self.output_padding``.
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        # Pass the actual input `x` (it is inspected when output_size is set)
        # and num_spatial_dims positionally, matching nn.ConvTranspose2d.forward.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose2d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

        return result

    def _get_name(self):
        return "QuantizedConvTranspose2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)


class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose3d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
            w(float) -- quant - dequant \
            x(float) ------------- F.conv_transpose3d ---
        In the full model, we will see
            w(float) -- quant - *dequant \
            x -- quant --- *dequant -- *F.conv_transpose3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose3d

        Args:
            x: the input tensor.
            output_size: optional requested output size; when given, the
                output padding is derived from it via ``_output_padding``
                instead of using ``self.output_padding``.
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        # Pass the actual input `x` (it is inspected when output_size is set)
        # and num_spatial_dims positionally, matching nn.ConvTranspose3d.forward.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose3d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
        return result

    def _get_name(self):
        return "QuantizedConvTranspose3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)