xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/modules/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Quantized convolution modules."""
3
4from typing import List, Optional, TypeVar
5
6import torch
7import torch.ao.nn.intrinsic as nni
8import torch.ao.nn.intrinsic.qat as nniqat
9import torch.nn as nn
10import torch.nn.functional as F
11from torch._ops import ops
12from torch.nn.common_types import _size_1_t
13from torch.nn.modules.utils import _pair, _single, _triple
14from torch.nn.utils import fuse_conv_bn_weights
15
16from .utils import _quantize_weight, WeightedQuantizedModule
17
18
19__all__ = [
20    "Conv1d",
21    "Conv2d",
22    "Conv3d",
23    "ConvTranspose1d",
24    "ConvTranspose2d",
25    "ConvTranspose3d",
26]
27
28_SUPPORTED_PADDING = {"zeros", "reflect"}
29
30
31def _reverse_repeat_padding(padding: List[int]) -> List[int]:
32    _reversed_padding_repeated_twice: List[int] = []
33    N = len(padding)
34    for idx in range(N):
35        for _ in range(2):
36            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
37    return _reversed_padding_repeated_twice
38
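# Illustrative note (not part of the original module): the helper above turns an
# NCHW-style padding tuple into the argument expected by torch.nn.functional.pad,
# which lists the last spatial dimension first and repeats each value for both
# sides of that dimension, e.g.
#
#     >>> _reverse_repeat_padding([1, 2])   # (pH, pW) -> (pW, pW, pH, pH)
#     [2, 2, 1, 1]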
39
40class _ConvNd(WeightedQuantizedModule):
41    def __init__(
42        self,
43        in_channels,
44        out_channels,
45        kernel_size,
46        stride=1,
47        padding=0,
48        dilation=1,
49        groups=1,
50        bias=True,
51        padding_mode="zeros",
52        device=None,
53        dtype=None,
54    ):
55        # All subclasses have this signature - See PR #49702
56        raise NotImplementedError
57
58    def _init(
59        self,
60        in_channels,
61        out_channels,
62        kernel_size,
63        stride,
64        padding,
65        dilation,
66        transposed,
67        output_padding,
68        groups,
69        bias,
70        padding_mode="zeros",
71        device=None,
72        dtype=None,
73    ) -> None:
74        factory_kwargs = {"device": device, "dtype": dtype}
75        super().__init__()
76
77        if in_channels % groups != 0:
78            raise ValueError("in_channels must be divisible by groups")
79        if out_channels % groups != 0:
80            raise ValueError("out_channels must be divisible by groups")
81        self.in_channels = in_channels
82        self.out_channels = out_channels
83        self.kernel_size = kernel_size
84        self.stride = stride
85        self.padding = padding
86        self.dilation = dilation
87        self.transposed = transposed
88        self.output_padding = output_padding
89        self.groups = groups
90        if padding_mode not in _SUPPORTED_PADDING:
91            raise ValueError(
92                f"'padding_mode' {padding_mode} is not supported by quantized convolution"
93            )
94        self.padding_mode = padding_mode
95        # Initialize as NCHW. set_weight_bias will internally transpose to NHWC.
96        if self.transposed:
97            weight_shape = [in_channels, out_channels // self.groups]
98        else:
99            weight_shape = [out_channels, in_channels // self.groups]
100        qweight = torch._empty_affine_quantized(
101            weight_shape + list(kernel_size),
102            scale=1,
103            zero_point=0,
104            dtype=torch.qint8,
105            **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
106        )
107        bias_float = (
108            torch.zeros(
109                out_channels,
110                dtype=torch.float,
111                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
112            )
113            if bias
114            else None
115        )
116
117        self.set_weight_bias(qweight, bias_float)
118        self.scale = 1.0
119        self.zero_point = 0
120
121    def set_weight_bias(self, qweight, bias_float):
122        raise NotImplementedError
123
124    def bias(self):
125        raise NotImplementedError
126
127    def _weight_bias(self):
128        raise NotImplementedError
129
130    def extra_repr(self):
131        s = (
132            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
133            ", stride={stride}, scale={scale}, zero_point={zero_point}"
134        )
135        if self.padding != (0,) * len(self.padding):
136            s += ", padding={padding}"
137        if self.dilation != (1,) * len(self.dilation):
138            s += ", dilation={dilation}"
139        if self.output_padding != (0,) * len(self.output_padding):
140            s += ", output_padding={output_padding}"
141        if self.groups != 1:
142            s += ", groups={groups}"
143        if self.bias() is None:
144            s += ", bias=False"
145        return s.format(**self.__dict__)
146
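    # Illustrative sketch (not from the original file; assumes a build with a
    # quantized engine available) of the resulting repr:
    #
    #     >>> import torch.ao.nn.quantized as nnq
    #     >>> nnq.Conv2d(3, 8, 3)
    #     QuantizedConv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0)
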
147    # ===== Serialization methods =====
148    # The special consideration here is that we have to unpack the weights into
149    # their regular QTensor form for serialization. Packed weights should not
150    # live outside the process in which they were created; instead, they should be
151    # re-derived from the QTensor weight on load (see the sketch after _save_to_state_dict).
152    #   self
153    #   |--- weight : Tensor
154    #   |--- bias : Tensor
155    #
156    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
157    #   self
158    #   |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
159    def _save_to_state_dict(self, destination, prefix, keep_vars):
160        super()._save_to_state_dict(destination, prefix, keep_vars)
161        (w, b) = self._weight_bias()
162        destination[prefix + "weight"] = w
163        destination[prefix + "bias"] = b
164        destination[prefix + "scale"] = torch.tensor(self.scale)
165        destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
166
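    # Illustrative sketch (not from the original file): the state dict round-trips
    # through regular QTensors, so a freshly constructed module re-packs them on load.
    #
    #     >>> import torch.ao.nn.quantized as nnq
    #     >>> m = nnq.Conv2d(3, 8, 3)
    #     >>> sd = m.state_dict()            # holds 'weight', 'bias', 'scale', 'zero_point'
    #     >>> m2 = nnq.Conv2d(3, 8, 3)
    #     >>> _ = m2.load_state_dict(sd)     # set_weight_bias re-packs the QTensor weight
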
167    @torch.jit.export
168    def __getstate__(self):
169        (w, b) = self._weight_bias()
170        return (
171            self.in_channels,
172            self.out_channels,
173            self.kernel_size,
174            self.stride,
175            self.padding,
176            self.dilation,
177            self.transposed,
178            self.output_padding,
179            self.groups,
180            self.padding_mode,
181            w,
182            b,
183            self.scale,
184            self.zero_point,
185            self.training,
186        )
187
188    # ===== Deserialization methods =====
189    # Counterpart to the serialization methods: we must pack the serialized
190    # QTensor weight into its packed format for use by the FBGEMM ops.
191    def _load_from_state_dict(
192        self,
193        state_dict,
194        prefix,
195        local_metadata,
196        strict,
197        missing_keys,
198        unexpected_keys,
199        error_msgs,
200    ):
201        self.set_weight_bias(state_dict[prefix + "weight"], state_dict[prefix + "bias"])
202        state_dict.pop(prefix + "weight")
203        state_dict.pop(prefix + "bias")
204        self.scale = float(state_dict[prefix + "scale"])
205        state_dict.pop(prefix + "scale")
206        self.zero_point = int(state_dict[prefix + "zero_point"])
207        state_dict.pop(prefix + "zero_point")
208        super()._load_from_state_dict(
209            state_dict,
210            prefix,
211            local_metadata,
212            False,
213            missing_keys,
214            unexpected_keys,
215            error_msgs,
216        )
217
218    @torch.jit.export
219    def __setstate__(self, state):
220        self.in_channels = state[0]
221        self.out_channels = state[1]
222        self.kernel_size = state[2]
223        self.stride = state[3]
224        self.padding = state[4]
225        self.dilation = state[5]
226        self.transposed = state[6]
227        self.output_padding = state[7]
228        self.groups = state[8]
229        self.padding_mode = state[9]
230        self.set_weight_bias(state[10], state[11])
231        self.scale = state[12]
232        self.zero_point = state[13]
233        self.training = state[14]
234
235    def __deepcopy__(self, memo):
236        new_instance = type(self).__new__(type(self))
237        torch.nn.Module.__init__(new_instance)
238        state = self.__getstate__()
239        new_instance.__setstate__(state)
240        return new_instance
241
242    def __copy__(self):
243        return self.__deepcopy__({})
244
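    # Illustrative sketch (not from the original file): copies are routed through
    # the __getstate__/__setstate__ pair above, so the packed params are rebuilt
    # in the copy rather than shared with the original.
    #
    #     >>> import copy
    #     >>> import torch.ao.nn.quantized as nnq
    #     >>> m = nnq.Conv2d(3, 8, 3)
    #     >>> m_copy = copy.deepcopy(m)
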
245    @classmethod
246    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
247        r"""Creates a qconv object and returns it."""
248        if weight_post_process is None:
249            weight_post_process = mod.qconfig.weight()
250        weight_post_process(mod.weight)
251        assert (
252            weight_post_process.dtype == torch.qint8
253        ), "Weight observer must have a dtype of qint8"
254        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
255        # the __init__ call used is the one from derived classes and not the one from _ConvNd
256        qconv = cls(
257            mod.in_channels,
258            mod.out_channels,
259            mod.kernel_size,
260            mod.stride,
261            mod.padding,
262            mod.dilation,
263            mod.groups,
264            mod.bias is not None,
265            mod.padding_mode,
266        )
267        qconv.set_weight_bias(qweight, mod.bias)
268        if (
269            activation_post_process is None
270            or activation_post_process.dtype == torch.float
271        ):
272            return qconv  # dynamic quantization doesn't need scale/zero_point
273        else:
274            act_scale, act_zp = activation_post_process.calculate_qparams()
275            qconv.scale = float(act_scale)
276            qconv.zero_point = int(act_zp)
277            return qconv
278
279    @staticmethod
280    def from_float(cls, mod, use_precomputed_fake_quant=False):
281        if hasattr(mod, "weight_fake_quant"):
282            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
283            # ".from_float only works for " + cls.__QAT_MODULE.__name__
284            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
285                mod.weight, mod.bias = fuse_conv_bn_weights(
286                    mod.weight,
287                    mod.bias,
288                    mod.bn.running_mean,
289                    mod.bn.running_var,
290                    mod.bn.eps,
291                    mod.bn.weight,
292                    mod.bn.bias,
293                )
294            assert hasattr(
295                mod, "activation_post_process"
296            ), "Input QAT module must have observer attached"
297            weight_post_process = mod.weight_fake_quant
298            activation_post_process = mod.activation_post_process
299        else:
300            assert type(mod) == cls._FLOAT_MODULE, (
301                " nnq."
302                + cls.__name__
303                + ".from_float only works for "
304                + cls._FLOAT_MODULE.__name__
305                + " but got:"
306                + str(type(mod))
307            )
308            assert hasattr(
309                mod, "qconfig"
310            ), "Input float module must have qconfig defined."
311            activation_post_process = (
312                None
313                if not hasattr(mod, "activation_post_process")
314                else mod.activation_post_process
315            )
316            if type(mod) in [
317                cls._NNI_CONV_RELU_MODULE,
318                cls._NNI_CONV_ADD_MODULE,
319                cls._NNI_CONV_ADD_RELU_MODULE,
320            ]:
321                mod = mod[0]
322            weight_post_process = mod.qconfig.weight()
323        return cls.get_qconv(mod, activation_post_process, weight_post_process)
324
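    # Illustrative sketch (not from the original file): from_float is normally
    # reached through the eager-mode prepare/convert flow rather than called
    # directly. Roughly:
    #
    #     >>> float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    #     >>> float_model.qconfig = torch.ao.quantization.default_qconfig
    #     >>> prepared = torch.ao.quantization.prepare(float_model)
    #     >>> _ = prepared(torch.randn(2, 3, 16, 16))               # calibrate observers
    #     >>> quantized = torch.ao.quantization.convert(prepared)   # swaps in the quantized Conv2d
    #     >>> import torch.ao.nn.quantized as nnq
    #     >>> isinstance(quantized[0], nnq.Conv2d)
    #     True
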
325    @classmethod
326    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
327        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
328        Args:
329            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
330                                utilities or provided by the user
331            output_scale (float): scale for output Tensor
332            output_zero_point (int): zero point for output Tensor
333        """
334        qconv = cls(
335            ref_qconv.in_channels,
336            ref_qconv.out_channels,
337            ref_qconv.kernel_size,  # type: ignore[arg-type]
338            ref_qconv.stride,  # type: ignore[arg-type]
339            ref_qconv.padding,  # type: ignore[arg-type]
340            ref_qconv.dilation,  # type: ignore[arg-type]
341            ref_qconv.groups,
342            ref_qconv.bias is not None,  # type: ignore[arg-type]
343            ref_qconv.padding_mode,
344            device=ref_qconv.weight.device,
345            dtype=ref_qconv.weight.dtype,
346        )
347        qweight = ref_qconv.get_quantized_weight()
348        qconv.set_weight_bias(qweight, ref_qconv.bias)
349        qconv.scale = float(output_scale)
350        qconv.zero_point = int(output_zero_point)
351        return qconv
352
353
354class Conv1d(_ConvNd):
355    r"""Applies a 1D convolution over a quantized input signal composed of
356    several quantized input planes.
357
358    For details on input arguments, parameters, and implementation see
359    :class:`~torch.nn.Conv1d`.
360
361    .. note::
362        Only `zeros` and `reflect` are supported for the :attr:`padding_mode` argument.
363
364    .. note::
365        Only `torch.quint8` is supported for the input data type.
366
367
368    Attributes:
369        weight (Tensor):     packed tensor derived from the learnable weight
370                             parameter.
371        scale (Tensor):      scalar for the output scale
372        zero_point (Tensor): scalar for the output zero point
373
374    See :class:`~torch.nn.Conv1d` for other attributes.
375
376    Examples::
377
378        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
379        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
380        >>> input = torch.randn(20, 16, 100)
381        >>> # quantize input to quint8
382        >>> # xdoctest: +SKIP
383        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
384        ...                                     dtype=torch.quint8)
385        >>> output = m(q_input)
386
387    """
388
389    _FLOAT_MODULE = nn.Conv1d
390    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
391    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
392    _NNI_CONV_ADD_MODULE: None = None
393    _NNI_CONV_ADD_RELU_MODULE: None = None
394
395    def __init__(
396        self,
397        in_channels: int,
398        out_channels: int,
399        kernel_size: _size_1_t,
400        stride: _size_1_t = 1,
401        padding: _size_1_t = 0,
402        dilation: _size_1_t = 1,
403        groups: int = 1,
404        bias: bool = True,
405        padding_mode: str = "zeros",
406        device=None,
407        dtype=None,
408    ):
409        factory_kwargs = {"device": device, "dtype": dtype}
410        kernel_size = _single(kernel_size)
411        stride = _single(stride)
412        padding = padding if isinstance(padding, str) else _single(padding)
413        dilation = _single(dilation)
414
415        # Subclasses of _ConvNd need to call _init rather than __init__. See
416        # discussion on PR #49702
417        super()._init(
418            in_channels,
419            out_channels,
420            kernel_size,
421            stride,
422            padding,
423            dilation,
424            False,
425            _single(0),
426            groups,
427            bias,
428            padding_mode,
429            **factory_kwargs,
430        )
431
432    def _get_name(self):
433        return "QuantizedConv1d"
434
435    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
436        if self.padding_mode == "zeros":
437            self._packed_params = torch.ops.quantized.conv1d_prepack(
438                w, b, self.stride, self.padding, self.dilation, self.groups
439            )
440        else:
441            self._packed_params = torch.ops.quantized.conv1d_prepack(
442                w, b, self.stride, _pair(0), self.dilation, self.groups
443            )
444
445    def _weight_bias(self):
446        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
447        return w, b
448
449    def weight(self):
450        return self._weight_bias()[0]
451
452    def bias(self):
453        return self._weight_bias()[1]
454
455    def forward(self, input):
456        # Temporarily using len(shape) instead of ndim due to JIT issue
457        # https://github.com/pytorch/pytorch/issues/23890
458        if len(input.shape) != 3:
459            raise ValueError("Input shape must be `(N, C, L)`!")
460        if self.padding_mode != "zeros":
461            # Padding in Conv1d is stored as (p, p), need to get (p,)
462            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
463            input = F.pad(
464                input, _reversed_padding_repeated_twice, mode=self.padding_mode
465            )
466        return ops.quantized.conv1d(
467            input, self._packed_params, self.scale, self.zero_point
468        )
469
470    @classmethod
471    def from_float(cls, mod, use_precomputed_fake_quant=False):
472        r"""Creates a quantized module from a float module or qparams_dict.
473
474        Args:
475            mod (Module): a float module, either produced by torch.ao.quantization
476              utilities or provided by the user
477        """
478        return _ConvNd.from_float(
479            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
480        )
481
482
483class Conv2d(_ConvNd):
484    r"""Applies a 2D convolution over a quantized input signal composed of
485    several quantized input planes.
486
487    For details on input arguments, parameters, and implementation see
488    :class:`~torch.nn.Conv2d`.
489
490    .. note::
491        Only `zeros` and `reflect` are supported for the :attr:`padding_mode` argument.
492
493    .. note::
494        Only `torch.quint8` is supported for the input data type.
495
496
497    Attributes:
498        weight (Tensor):     packed tensor derived from the learnable weight
499                             parameter.
500        scale (Tensor):      scalar for the output scale
501        zero_point (Tensor): scalar for the output zero point
502
503    See :class:`~torch.nn.Conv2d` for other attributes.
504
505    Examples::
506
507        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
508        >>> # With square kernels and equal stride
509        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
510        >>> # non-square kernels and unequal stride and with padding
511        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
512        >>> # non-square kernels and unequal stride and with padding and dilation
513        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
514        >>> input = torch.randn(20, 16, 50, 100)
515        >>> # quantize input to quint8
516        >>> # xdoctest: +SKIP
517        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
518        >>> output = m(q_input)
519
520    """
521    _FLOAT_MODULE = nn.Conv2d
522    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
523    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
524    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
525    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d
526
527    def __init__(
528        self,
529        in_channels,
530        out_channels,
531        kernel_size,
532        stride=1,
533        padding=0,
534        dilation=1,
535        groups=1,
536        bias=True,
537        padding_mode="zeros",
538        device=None,
539        dtype=None,
540    ):
541        factory_kwargs = {"device": device, "dtype": dtype}
542        kernel_size = _pair(kernel_size)
543        stride = _pair(stride)
544        padding = _pair(padding)
545        dilation = _pair(dilation)
546        # Subclasses of _ConvNd need to call _init rather than __init__. See
547        # discussion on PR #49702
548        super()._init(
549            in_channels,
550            out_channels,
551            kernel_size,
552            stride,
553            padding,
554            dilation,
555            False,
556            _pair(0),
557            groups,
558            bias,
559            padding_mode,
560            **factory_kwargs,
561        )
562
563    def _get_name(self):
564        return "QuantizedConv2d"
565
566    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
567        if self.padding_mode == "zeros":
568            self._packed_params = torch.ops.quantized.conv2d_prepack(
569                w, b, self.stride, self.padding, self.dilation, self.groups
570            )
571        else:
572            self._packed_params = torch.ops.quantized.conv2d_prepack(
573                w, b, self.stride, _pair(0), self.dilation, self.groups
574            )
575
576    def _weight_bias(self):
577        return self._packed_params.unpack()
578
579    def weight(self):
580        return self._weight_bias()[0]
581
582    def bias(self):
583        return self._weight_bias()[1]
584
585    def forward(self, input):
586        # Temporarily using len(shape) instead of ndim due to JIT issue
587        # https://github.com/pytorch/pytorch/issues/23890
588        if len(input.shape) != 4:
589            raise ValueError("Input shape must be `(N, C, H, W)`!")
590        if self.padding_mode != "zeros":
591            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
592            input = F.pad(
593                input, _reversed_padding_repeated_twice, mode=self.padding_mode
594            )
595        return ops.quantized.conv2d(
596            input, self._packed_params, self.scale, self.zero_point
597        )
598
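    # Illustrative sketch (not from the original file): with padding_mode="reflect"
    # the reflection padding is applied eagerly via F.pad above, and the packed conv
    # then runs with zero padding.
    #
    #     >>> import torch.ao.nn.quantized as nnq
    #     >>> m = nnq.Conv2d(3, 8, 3, padding=1, padding_mode="reflect")
    #     >>> x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1,
    #     ...                               zero_point=128, dtype=torch.quint8)
    #     >>> m(x).shape
    #     torch.Size([1, 8, 8, 8])
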
599    @classmethod
600    def from_float(cls, mod, use_precomputed_fake_quant=False):
601        r"""Creates a quantized module from a float module or qparams_dict.
602
603        Args:
604            mod (Module): a float module, either produced by torch.ao.quantization
605              utilities or provided by the user
606        """
607        return _ConvNd.from_float(
608            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
609        )
610
611
612class Conv3d(_ConvNd):
613    r"""Applies a 3D convolution over a quantized input signal composed of
614    several quantized input planes.
615
616    For details on input arguments, parameters, and implementation see
617    :class:`~torch.nn.Conv3d`.
618
619    .. note::
620        Only `zeros` is supported for the :attr:`padding_mode` argument.
621
622    .. note::
623        Only `torch.quint8` is supported for the input data type.
624
625
626    Attributes:
627        weight (Tensor):     packed tensor derived from the learnable weight
628                             parameter.
629        scale (Tensor):      scalar for the output scale
630        zero_point (Tensor): scalar for the output zero point
631
632    See :class:`~torch.nn.Conv3d` for other attributes.
633
634    Examples::
635
636        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
637        >>> # With square kernels and equal stride
638        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
639        >>> # non-square kernels and unequal stride and with padding
640        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
641        >>> # non-square kernels and unequal stride and with padding and dilation
642        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
643        >>> input = torch.randn(20, 16, 56, 56, 56)
644        >>> # quantize input to quint8
645        >>> # xdoctest: +SKIP
646        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
647        >>> output = m(q_input)
648
649    """
650    _FLOAT_MODULE = nn.Conv3d
651    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
652    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
653    _NNI_CONV_ADD_MODULE: None = None
654    _NNI_CONV_ADD_RELU_MODULE: None = None
655
656    def __init__(
657        self,
658        in_channels,
659        out_channels,
660        kernel_size,
661        stride=1,
662        padding=0,
663        dilation=1,
664        groups=1,
665        bias=True,
666        padding_mode="zeros",
667        device=None,
668        dtype=None,
669    ):
670        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
671        factory_kwargs = {"device": device, "dtype": dtype}
672        kernel_size = _triple(kernel_size)
673        stride = _triple(stride)
674        padding = _triple(padding)
675        dilation = _triple(dilation)
676        # Subclasses of _ConvNd need to call _init rather than __init__. See
677        # discussion on PR #49702
678        super()._init(
679            in_channels,
680            out_channels,
681            kernel_size,
682            stride,
683            padding,
684            dilation,
685            False,
686            _triple(0),
687            groups,
688            bias,
689            padding_mode,
690            **factory_kwargs,
691        )
692
693    def _get_name(self):
694        return "QuantizedConv3d"
695
696    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
697        if self.padding_mode == "zeros":
698            self._packed_params = torch.ops.quantized.conv3d_prepack(
699                w, b, self.stride, self.padding, self.dilation, self.groups
700            )
701        else:
702            self._packed_params = torch.ops.quantized.conv3d_prepack(
703                w, b, self.stride, _triple(0), self.dilation, self.groups
704            )
705
706    def _weight_bias(self):
707        return self._packed_params.unpack()
708
709    def weight(self):
710        return self._weight_bias()[0]
711
712    def bias(self):
713        return self._weight_bias()[1]
714
715    def forward(self, input):
716        # Temporarily using len(shape) instead of ndim due to JIT issue
717        # https://github.com/pytorch/pytorch/issues/23890
718        if len(input.shape) != 5:
719            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
720        if self.padding_mode != "zeros":
721            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
722            input = F.pad(
723                input, _reversed_padding_repeated_twice, mode=self.padding_mode
724            )
725        return ops.quantized.conv3d(
726            input, self._packed_params, self.scale, self.zero_point
727        )
728
729    @classmethod
730    def from_float(cls, mod, use_precomputed_fake_quant=False):
731        r"""Creates a quantized module from a float module or qparams_dict.
732
733        Args:
734            mod (Module): a float module, either produced by torch.ao.quantization
735              utilities or provided by the user
736        """
737        return _ConvNd.from_float(
738            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
739        )
740
741
742# === Transposed Convolutions ===
743MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
744
745
746class _ConvTransposeNd(_ConvNd):
747    _FLOAT_MODULE = MOD
748
749    def __init__(
750        self,
751        in_channels,
752        out_channels,
753        kernel_size,
754        stride,
755        padding,
756        dilation,
757        transposed,
758        output_padding,
759        groups,
760        bias,
761        padding_mode,
762        device=None,
763        dtype=None,
764    ):
765        if padding_mode != "zeros":
766            raise ValueError(
767                f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
768            )
769        factory_kwargs = {"device": device, "dtype": dtype}
770        # Subclasses of _ConvNd need to call _init rather than __init__. See
771        # discussion on PR #49702
772        super()._init(
773            in_channels,
774            out_channels,
775            kernel_size,
776            stride,
777            padding,
778            dilation,
779            transposed,
780            output_padding,
781            groups,
782            bias,
783            padding_mode,
784            **factory_kwargs,
785        )
786
787    def _input_padding(
788        self, kernel_size: List[int], dilation: List[int], padding: List[int]
789    ) -> List[int]:
790        res = torch.jit.annotate(List[int], [])
791        for kdx in range(len(kernel_size)):
792            pad = dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx]
793            res.append(pad)
794        return res
795
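    # Worked example (not from the original file): for kernel_size=[3, 3],
    # dilation=[1, 1], padding=[1, 1] this returns
    # [1 * (3 - 1) - 1, 1 * (3 - 1) - 1] == [1, 1].
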
796    @classmethod
797    def from_float(cls, mod, use_precomputed_fake_quant=False):
798        r"""Creates a quantized module from a float module or qparams_dict.
799        Args:
800            mod (Module): a float module, either produced by torch.ao.quantization
801              utilities or provided by the user
802        """
803        # derived classes override cls._FLOAT_MODULE attribute
804        msg = (
805            " nnq."
806            + cls.__name__
807            + ".from_float only works for "
808            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
809        )
810        assert type(mod) == cls._FLOAT_MODULE, msg
811        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined."
812        weight_post_process = mod.qconfig.weight()
813        weight_post_process(mod.weight)
814        assert (
815            weight_post_process.dtype == torch.qint8
816        ), "Weight observer must have a dtype of qint8"
817        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
818        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
819        qconv = cls(
820            mod.in_channels,
821            mod.out_channels,
822            mod.kernel_size,  # type: ignore[call-arg]
823            mod.stride,
824            mod.padding,
825            mod.output_padding,
826            mod.groups,
827            mod.bias is not None,
828            mod.dilation,
829            mod.padding_mode,
830        )
831        qconv.set_weight_bias(qweight, mod.bias)
832        if (
833            not hasattr(mod, "activation_post_process")
834            or mod.activation_post_process.dtype == torch.float
835        ):
836            return qconv  # dynamic quantization doesn't need scale/zero_point
837        else:
838            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
839            qconv.scale = float(act_scale)
840            qconv.zero_point = int(act_zp)
841            return qconv
842
843    @staticmethod
844    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
845        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
846        Args:
847            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
848                                 utilities or provided by the user
849            output_scale (float): scale for output Tensor
850            output_zero_point (int): zero point for output Tensor
851        """
852        qconv = cls(
853            ref_qconvt.in_channels,
854            ref_qconvt.out_channels,
855            ref_qconvt.kernel_size,  # type: ignore[arg-type]
856            ref_qconvt.stride,  # type: ignore[arg-type]
857            ref_qconvt.padding,  # type: ignore[arg-type]
858            ref_qconvt.output_padding,  # type: ignore[arg-type]
859            ref_qconvt.groups,
860            ref_qconvt.bias is not None,  # type: ignore[arg-type]
861            ref_qconvt.dilation,  # type: ignore[arg-type]
862            ref_qconvt.padding_mode,
863            device=ref_qconvt.weight.device,
864            dtype=ref_qconvt.weight.dtype,
865        )
866        qweight = ref_qconvt.get_quantized_weight()
867        qconv.set_weight_bias(qweight, ref_qconvt.bias)
868        qconv.scale = float(output_scale)
869        qconv.zero_point = int(output_zero_point)
870        return qconv
871
872
873class ConvTranspose1d(_ConvTransposeNd):
874    r"""Applies a 1D transposed convolution operator over an input image
875    composed of several input planes.
876    For details on input arguments, parameters, and implementation see
877    :class:`~torch.nn.ConvTranspose1d`.
878
879    .. note:: Currently only the QNNPACK engine is implemented.
880        Please set `torch.backends.quantized.engine = 'qnnpack'`.
881
882    For special notes, please see :class:`~torch.ao.nn.quantized.Conv1d`.
883
884    Attributes:
885        weight (Tensor):     packed tensor derived from the learnable weight
886                             parameter.
887        scale (Tensor):      scalar for the output scale
888        zero_point (Tensor): scalar for the output zero point
889    See :class:`~torch.nn.ConvTranspose1d` for other attributes.
890
891    Examples::
892
893        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
894        >>> torch.backends.quantized.engine = 'qnnpack'
895        >>> from torch.ao.nn import quantized as nnq
896        >>> # With square kernels and equal stride
897        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
898        >>> # unequal stride and padding
899        >>> m = nnq.ConvTranspose1d(16, 33, 5, stride=2, padding=2)
900        >>> input = torch.randn(20, 16, 50)
901        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
902        >>> output = m(q_input)
903        >>> # exact output size can also be specified as an argument
904        >>> input = torch.randn(1, 16, 12)
905        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
906        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
907        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
908        >>> h = downsample(q_input)
909        >>> h.size()
910        torch.Size([1, 16, 6])
911        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
912        >>> output = upsample(h, output_size=input.size())
913        >>> output.size()
914        torch.Size([1, 16, 12])
915    """
916
917    _FLOAT_MODULE = nn.ConvTranspose1d
918
919    def __init__(
920        self,
921        in_channels,
922        out_channels,
923        kernel_size,
924        stride=1,
925        padding=0,
926        output_padding=0,
927        groups=1,
928        bias=True,
929        dilation=1,
930        padding_mode="zeros",
931        device=None,
932        dtype=None,
933    ):
934        factory_kwargs = {"device": device, "dtype": dtype}
935        kernel_size = _single(kernel_size)
936        stride = _single(stride)
937        padding = _single(padding)
938        dilation = _single(dilation)
939        output_padding = _single(output_padding)
940
941        super().__init__(
942            in_channels,
943            out_channels,
944            kernel_size,
945            stride,
946            padding,
947            dilation,
948            True,
949            output_padding,
950            groups,
951            bias,
952            padding_mode,
953            **factory_kwargs,
954        )
955
956    def _get_name(self):
957        return "QuantizedConvTranspose1d"
958
959    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
960        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
961            w,
962            b,
963            self.stride,
964            self.padding,
965            self.output_padding,
966            self.dilation,
967            self.groups,
968        )
969
970    def _weight_bias(self):
971        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
972        return w, b
973
974    def weight(self):
975        (w, _) = self._weight_bias()
976        return w
977
978    def bias(self):
979        (_, b) = self._weight_bias()
980        return b
981
982    def forward(self, input):
983        # Temporarily using len(shape) instead of ndim due to JIT issue
984        # https://github.com/pytorch/pytorch/issues/23890
985        if len(input.shape) != 3:
986            raise ValueError("Input shape must be `(N, C, L)`!")
987        return torch.ops.quantized.conv_transpose1d(
988            input, self._packed_params, self.scale, self.zero_point
989        )
990
991    @classmethod
992    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
993        return _ConvTransposeNd.from_reference(
994            cls, ref_qconvt, output_scale, output_zero_point
995        )
996
997
998class ConvTranspose2d(_ConvTransposeNd):
999    r"""Applies a 2D transposed convolution operator over an input image
1000    composed of several input planes.
1001    For details on input arguments, parameters, and implementation see
1002    :class:`~torch.nn.ConvTranspose2d`.
1003
1004    For special notes, please see :class:`~torch.ao.nn.quantized.Conv2d`.
1005
1006    Attributes:
1007        weight (Tensor):     packed tensor derived from the learnable weight
1008                             parameter.
1009        scale (Tensor):      scalar for the output scale
1010        zero_point (Tensor): scalar for the output zero point
1011    See :class:`~torch.nn.ConvTranspose2d` for other attributes.
1012
1013    Examples::
1014
1015        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
1016        >>> # QNNPACK or FBGEMM as backend
1017        >>> torch.backends.quantized.engine = 'qnnpack'
1018        >>> # With square kernels and equal stride
1019        >>> import torch.ao.nn.quantized as nnq
1020        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
1021        >>> # non-square kernels and unequal stride and with padding
1022        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
1023        >>> input = torch.randn(20, 16, 50, 100)
1024        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
1025        >>> output = m(q_input)
1026        >>> # exact output size can also be specified as an argument
1027        >>> input = torch.randn(1, 16, 12, 12)
1028        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
1029        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
1030        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
1031        >>> h = downsample(q_input)
1032        >>> h.size()
1033        torch.Size([1, 16, 6, 6])
1034        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
1035        >>> output = upsample(h, output_size=input.size())
1036        >>> output.size()
1037        torch.Size([1, 16, 12, 12])
1038    """
1039
1040    _FLOAT_MODULE = nn.ConvTranspose2d
1041
1042    def __init__(
1043        self,
1044        in_channels,
1045        out_channels,
1046        kernel_size,
1047        stride=1,
1048        padding=0,
1049        output_padding=0,
1050        groups=1,
1051        bias=True,
1052        dilation=1,
1053        padding_mode="zeros",
1054        device=None,
1055        dtype=None,
1056    ):
1057        factory_kwargs = {"device": device, "dtype": dtype}
1058        kernel_size = _pair(kernel_size)
1059        stride = _pair(stride)
1060        padding = _pair(padding)
1061        dilation = _pair(dilation)
1062        output_padding = _pair(output_padding)
1063
1064        super().__init__(
1065            in_channels,
1066            out_channels,
1067            kernel_size,
1068            stride,
1069            padding,
1070            dilation,
1071            True,
1072            output_padding,
1073            groups,
1074            bias,
1075            padding_mode,
1076            **factory_kwargs,
1077        )
1078
1079    def _get_name(self):
1080        return "QuantizedConvTranspose2d"
1081
1082    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
1083        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
1084            w,
1085            b,
1086            self.stride,
1087            self.padding,
1088            self.output_padding,
1089            self.dilation,
1090            self.groups,
1091        )
1092
1093    def _weight_bias(self):
1094        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
1095        return w, b
1096
1097    def weight(self):
1098        (w, _) = self._weight_bias()
1099        return w
1100
1101    def bias(self):
1102        (_, b) = self._weight_bias()
1103        return b
1104
1105    def forward(self, input):
1106        # Temporarily using len(shape) instead of ndim due to JIT issue
1107        # https://github.com/pytorch/pytorch/issues/23890
1108        if len(input.shape) != 4:
1109            raise ValueError("Input shape must be `(N, C, H, W)`!")
1110        return ops.quantized.conv_transpose2d(
1111            input, self._packed_params, self.scale, self.zero_point
1112        )
1113
1114    @classmethod
1115    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
1116        return _ConvTransposeNd.from_reference(
1117            cls, ref_qconvt, output_scale, output_zero_point
1118        )
1119
1120
1121class ConvTranspose3d(_ConvTransposeNd):
1122    r"""Applies a 3D transposed convolution operator over an input image
1123    composed of several input planes.
1124    For details on input arguments, parameters, and implementation see
1125    :class:`~torch.nn.ConvTranspose3d`.
1126
1127    .. note:: Currently only the FBGEMM engine is implemented.
1128        Please set `torch.backends.quantized.engine = 'fbgemm'`.
1129
1130    For special notes, please see :class:`~torch.ao.nn.quantized.Conv3d`.
1131
1132    Attributes:
1133        weight (Tensor):     packed tensor derived from the learnable weight
1134                             parameter.
1135        scale (Tensor):      scalar for the output scale
1136        zero_point (Tensor): scalar for the output zero point
1137    See :class:`~torch.nn.ConvTranspose3d` for other attributes.
1138
1139    Examples::
1140
1141        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
1142        >>> torch.backends.quantized.engine = 'fbgemm'
1143        >>> from torch.ao.nn import quantized as nnq
1144        >>> # With cubic kernels and equal stride
1145        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
1146        >>> # non-cubic kernels and unequal stride and with padding
1147        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
1148        >>> input = torch.randn(20, 16, 50, 100, 100)
1149        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
1150        >>> output = m(q_input)
1151        >>> # exact output size can also be specified as an argument
1152        >>> input = torch.randn(1, 16, 12, 12, 12)
1153        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
1154        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
1155        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
1156        >>> h = downsample(q_input)
1157        >>> h.size()
1158        torch.Size([1, 16, 6, 6, 6])
1159        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
1160        >>> output = upsample(h, output_size=input.size())
1161        >>> output.size()
1162        torch.Size([1, 16, 12, 12, 12])
1163    """
1164
1165    _FLOAT_MODULE = nn.ConvTranspose3d
1166
1167    def __init__(
1168        self,
1169        in_channels,
1170        out_channels,
1171        kernel_size,
1172        stride=1,
1173        padding=0,
1174        output_padding=0,
1175        groups=1,
1176        bias=True,
1177        dilation=1,
1178        padding_mode="zeros",
1179        device=None,
1180        dtype=None,
1181    ):
1182        factory_kwargs = {"device": device, "dtype": dtype}
1183        kernel_size = _triple(kernel_size)
1184        stride = _triple(stride)
1185        padding = _triple(padding)
1186        dilation = _triple(dilation)
1187        output_padding = _triple(output_padding)
1188
1189        super().__init__(
1190            in_channels,
1191            out_channels,
1192            kernel_size,
1193            stride,
1194            padding,
1195            dilation,
1196            True,
1197            output_padding,
1198            groups,
1199            bias,
1200            padding_mode,
1201            **factory_kwargs,
1202        )
1203
1204    def _get_name(self):
1205        return "QuantizedConvTranspose3d"
1206
1207    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
1208        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
1209            w,
1210            b,
1211            self.stride,
1212            self.padding,
1213            self.output_padding,
1214            self.dilation,
1215            self.groups,
1216        )
1217
1218    def _weight_bias(self):
1219        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
1220        return w, b
1221
1222    def weight(self):
1223        (w, _) = self._weight_bias()
1224        return w
1225
1226    def bias(self):
1227        (_, b) = self._weight_bias()
1228        return b
1229
1230    def forward(self, input):
1231        # Temporarily using len(shape) instead of ndim due to JIT issue
1232        # https://github.com/pytorch/pytorch/issues/23890
1233        if len(input.shape) != 5:
1234            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
1235        return ops.quantized.conv_transpose3d(
1236            input, self._packed_params, self.scale, self.zero_point
1237        )
1238
1239    @classmethod
1240    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
1241        return _ConvTransposeNd.from_reference(
1242            cls, ref_qconvt, output_scale, output_zero_point
1243        )
1244