# mypy: allow-untyped-defs
r"""Quantized convolution modules."""

from typing import List, Optional, TypeVar

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn as nn
import torch.nn.functional as F
from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.utils import fuse_conv_bn_weights

from .utils import _quantize_weight, WeightedQuantizedModule


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
]

_SUPPORTED_PADDING = {"zeros", "reflect"}


def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    # Convert per-dimension padding into the reversed, twice-repeated form
    # expected by F.pad, e.g. (padH, padW) -> (padW, padW, padH, padH).
    _reversed_padding_repeated_twice: List[int] = []
    N = len(padding)
    for idx in range(N):
        for _ in range(2):
            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
    return _reversed_padding_repeated_twice


class _ConvNd(WeightedQuantizedModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # All subclasses have this signature - See PR #49702
        raise NotImplementedError

    def _init(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if padding_mode not in _SUPPORTED_PADDING:
            raise ValueError(
                f"'padding_mode' {padding_mode} is not supported by quantized convolution"
            )
        self.padding_mode = padding_mode
        # Initialize as NCHW. set_weight will internally transpose to NHWC.
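        # Transposed convolutions store weights as
        # (in_channels, out_channels // groups, *kernel_size); regular
        # convolutions use (out_channels, in_channels // groups, *kernel_size).
        # The placeholder below is quantized with dummy qparams (scale=1,
        # zero_point=0); real values are attached later via set_weight_bias().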
        if self.transposed:
            weight_shape = [in_channels, out_channels // self.groups]
        else:
            weight_shape = [out_channels, in_channels // self.groups]
        qweight = torch._empty_affine_quantized(
            weight_shape + list(kernel_size),
            scale=1,
            zero_point=0,
            dtype=torch.qint8,
            **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
        )
        bias_float = (
            torch.zeros(
                out_channels,
                dtype=torch.float,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )
            if bias
            else None
        )

        self.set_weight_bias(qweight, bias_float)
        self.scale = 1.0
        self.zero_point = 0

    def set_weight_bias(self, qweight, bias_float):
        raise NotImplementedError

    def bias(self):
        raise NotImplementedError

    def _weight_bias(self):
        raise NotImplementedError

    def extra_repr(self):
        s = (
            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
            ", stride={stride}, scale={scale}, zero_point={zero_point}"
        )
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias() is None:
            s += ", bias=False"
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # The special consideration here is that we have to unpack the weights into
    # their regular QTensor form for serialization. Packed weights should not
    # live outside the process in which they were created; rather, they should be
    # derived from the QTensor weight.
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #
    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
    #   self
    #   |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + "weight"] = w
        destination[prefix + "bias"] = b
        destination[prefix + "scale"] = torch.tensor(self.scale)
        destination[prefix + "zero_point"] = torch.tensor(self.zero_point)

    @torch.jit.export
    def __getstate__(self):
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training,
        )

    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized
    # QTensor weight into its packed format for use by the FBGEMM ops.
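    # Note: the quantized-specific entries (weight, bias, scale, zero_point) are
    # consumed and popped from the state_dict below before delegating to the
    # base-class loader with strict=False, so they are not reported as
    # unexpected or missing keys.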
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.set_weight_bias(state_dict[prefix + "weight"], state_dict[prefix + "bias"])
        state_dict.pop(prefix + "weight")
        state_dict.pop(prefix + "bias")
        self.scale = float(state_dict[prefix + "scale"])
        state_dict.pop(prefix + "scale")
        self.zero_point = int(state_dict[prefix + "zero_point"])
        state_dict.pop(prefix + "zero_point")
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @torch.jit.export
    def __setstate__(self, state):
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        torch.nn.Module.__init__(new_instance)
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def __copy__(self):
        return self.__deepcopy__({})

    @classmethod
    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
        r"""Creates a qconv object and returns it."""
        if weight_post_process is None:
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert (
            weight_post_process.dtype == torch.qint8
        ), "Weight observer must have a dtype of qint8"
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvNd
        qconv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            mod.stride,
            mod.padding,
            mod.dilation,
            mod.groups,
            mod.bias is not None,
            mod.padding_mode,
        )
        qconv.set_weight_bias(qweight, mod.bias)
        if (
            activation_post_process is None
            or activation_post_process.dtype == torch.float
        ):
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if hasattr(mod, "weight_fake_quant"):
            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
            #     ".from_float only works for " + cls.__QAT_MODULE.__name__
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight,
                    mod.bias,
                    mod.bn.running_mean,
                    mod.bn.running_var,
                    mod.bn.eps,
                    mod.bn.weight,
                    mod.bn.bias,
                )
            assert hasattr(
                mod, "activation_post_process"
            ), "Input QAT module must have observer attached"
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, (
                " nnq."
                + cls.__name__
                + ".from_float only works for "
                + cls._FLOAT_MODULE.__name__
                + " but got:"
                + str(type(mod))
            )
            assert hasattr(
                mod, "qconfig"
            ), "Input float module must have qconfig defined."
            activation_post_process = (
                None
                if not hasattr(mod, "activation_post_process")
                else mod.activation_post_process
            )
            if type(mod) in [
                cls._NNI_CONV_RELU_MODULE,
                cls._NNI_CONV_ADD_MODULE,
                cls._NNI_CONV_ADD_RELU_MODULE,
            ]:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconv.in_channels,
            ref_qconv.out_channels,
            ref_qconv.kernel_size,  # type: ignore[arg-type]
            ref_qconv.stride,  # type: ignore[arg-type]
            ref_qconv.padding,  # type: ignore[arg-type]
            ref_qconv.dilation,  # type: ignore[arg-type]
            ref_qconv.groups,
            ref_qconv.bias is not None,  # type: ignore[arg-type]
            ref_qconv.padding_mode,
            device=ref_qconv.weight.device,
            dtype=ref_qconv.weight.dtype,
        )
        qweight = ref_qconv.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconv.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = padding if isinstance(padding, str) else _single(padding)
        dilation = _single(dilation)

        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _single(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConv1d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == "zeros":
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups
            )
        else:
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups
            )

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return ops.quantized.conv1d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                utilities or provided by the user
        """
        return _ConvNd.from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConv2d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == "zeros":
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups
            )
        else:
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups
            )

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return ops.quantized.conv2d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                utilities or provided by the user
        """
        return _ConvNd.from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConv3d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == "zeros":
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups
            )
        else:
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, _triple(0), self.dilation, self.groups
            )

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return ops.quantized.conv3d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                utilities or provided by the user
        """
        return _ConvNd.from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


# === Transposed Convolutions ===
MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)


class _ConvTransposeNd(_ConvNd):
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        device=None,
        dtype=None,
    ):
        if padding_mode != "zeros":
            raise ValueError(
                f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
            )
        factory_kwargs = {"device": device, "dtype": dtype}
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _input_padding(
        self, kernel_size: List[int], dilation: List[int], padding: List[int]
    ) -> List[int]:
        res = torch.jit.annotate(List[int], [])
        for kdx in range(len(kernel_size)):
            pad = dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx]
            res.append(pad)
        return res

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.
        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        msg = (
            " nnq."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert type(mod) == cls._FLOAT_MODULE, msg
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined."
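        # Run the weight observer from the module's qconfig over the float
        # weight to compute quantization parameters, then quantize and pack it.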
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert (
            weight_post_process.dtype == torch.qint8
        ), "Weight observer must have a dtype of qint8"
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        qconv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,  # type: ignore[call-arg]
            mod.stride,
            mod.padding,
            mod.output_padding,
            mod.groups,
            mod.bias is not None,
            mod.dilation,
            mod.padding_mode,
        )
        qconv.set_weight_bias(qweight, mod.bias)
        if (
            not hasattr(mod, "activation_post_process")
            or mod.activation_post_process.dtype == torch.float
        ):
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconvt.in_channels,
            ref_qconvt.out_channels,
            ref_qconvt.kernel_size,  # type: ignore[arg-type]
            ref_qconvt.stride,  # type: ignore[arg-type]
            ref_qconvt.padding,  # type: ignore[arg-type]
            ref_qconvt.output_padding,  # type: ignore[arg-type]
            ref_qconvt.groups,
            ref_qconvt.bias is not None,  # type: ignore[arg-type]
            ref_qconvt.dilation,  # type: ignore[arg-type]
            ref_qconvt.padding_mode,
            device=ref_qconvt.weight.device,
            dtype=ref_qconvt.weight.dtype,
        )
        qweight = ref_qconvt.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconvt.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class ConvTranspose1d(_ConvTransposeNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please set `torch.backends.quantized.engine = 'qnnpack'`.

    For special notes, please see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConvTranspose1d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
            w,
            b,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(
            cls, ref_qconvt, output_scale, output_zero_point
        )


class ConvTranspose2d(_ConvTransposeNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConvTranspose2d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
            w,
            b,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(
            cls, ref_qconvt, output_scale, output_zero_point
        )


class ConvTranspose3d(_ConvTransposeNd):
    r"""Applies a 3D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please set `torch.backends.quantized.engine = 'fbgemm'`.

    For special notes, please see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _get_name(self):
        return "QuantizedConvTranspose3d"

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
            w,
            b,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d(
            input, self._packed_params, self.scale, self.zero_point
        )

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(
            cls, ref_qconvt, output_scale, output_zero_point
        )