xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/reference/modules/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Dict, List, Optional
3
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7from torch.nn.common_types import _size_1_t
8
9from .utils import ReferenceQuantizedModule
10
11
# Public reference-quantized convolution modules exported by this file.
__all__ = [
    # Regular convolutions.
    "Conv1d", "Conv2d", "Conv3d",
    # Transposed convolutions.
    "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d",
]
20
21
class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
    """Common base for the reference (unpacked) quantized convolutions.

    Weights are deliberately kept unpacked here. Weight packing is an
    optimization for PyTorch's own quantized backends (fbgemm/qnnpack),
    whereas these reference modules are intended for use with other
    backends such as Glow. This base contributes the shared ``from_float``
    conversion; concrete subclasses provide ``__init__`` and ``forward``.
    """

    # ``bias`` may be None (a conv built with bias=False has no bias).
    __annotations__ = {"bias": Optional[torch.Tensor]}
    _IS_REFERENCE = True

    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        """Create a ``cls`` instance configured like ``float_conv``.

        This is a staticmethod that receives ``cls`` explicitly, so that each
        concrete subclass can forward its own class from its ``classmethod``
        wrapper.

        Args:
            cls: concrete reference conv class to instantiate.
            float_conv: float conv module whose hyperparameters, weight and
                optional bias are copied.
            weight_qparams: weight quantization parameters passed on to ``cls``.

        Returns:
            The new reference module. Its ``weight``/``bias`` are fresh
            ``Parameter`` objects wrapping ``detach()``-ed float tensors.
        """
        src_weight = float_conv.weight
        src_bias = float_conv.bias
        # Positional order matches the subclasses' __init__ signature:
        # (in, out, kernel, stride, padding, dilation, groups, bias, padding_mode)
        ctor_args = (
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,
            float_conv.stride,
            float_conv.padding,
            float_conv.dilation,
            float_conv.groups,
            src_bias is not None,
            float_conv.padding_mode,
        )
        converted = cls(
            *ctor_args,
            device=src_weight.device,
            dtype=src_weight.dtype,
            weight_qparams=weight_qparams,
        )
        converted.weight = torch.nn.Parameter(src_weight.detach())
        if src_bias is not None:
            converted.bias = torch.nn.Parameter(src_bias.detach())
        return converted
52
53
class Conv1d(_ConvNd, nn.Conv1d):
    """Reference quantized counterpart of ``nn.Conv1d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward. The
    convolution itself is the ordinary float op, so a backend can fuse the
    pattern into a real quantized conv1d.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv1d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv1d ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d

        The convolution is dispatched through the inherited
        ``nn.Conv1d._conv_forward``, so ``padding_mode`` is honoured exactly as
        in the float module. For ``padding_mode="zeros"`` this is a plain
        ``F.conv1d(x, w, bias, stride, padding, dilation, groups)`` call.
        """
        weight_quant_dequant = self.get_weight()
        # Calling F.conv1d directly with self.padding would silently treat
        # reflect/replicate/circular padding modes as zero padding.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.Conv1d`` into this reference module."""
        return _ConvNd.from_float(cls, float_conv, weight_qparams)
115
116
class Conv2d(_ConvNd, nn.Conv2d):
    """Reference quantized counterpart of ``nn.Conv2d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward. The
    convolution itself is the ordinary float op, so a backend can fuse the
    pattern into a real quantized conv2d.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv2d ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d

        The convolution is dispatched through the inherited
        ``nn.Conv2d._conv_forward``, so ``padding_mode`` is honoured exactly as
        in the float module. For ``padding_mode="zeros"`` this is a plain
        ``F.conv2d(x, w, bias, stride, padding, dilation, groups)`` call.
        """
        weight_quant_dequant = self.get_weight()
        # Calling F.conv2d directly with self.padding would silently treat
        # reflect/replicate/circular padding modes as zero padding.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.Conv2d`` into this reference module."""
        return _ConvNd.from_float(cls, float_conv, weight_qparams)
178
179
class Conv3d(_ConvNd, nn.Conv3d):
    """Reference quantized counterpart of ``nn.Conv3d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward. The
    convolution itself is the ordinary float op, so a backend can fuse the
    pattern into a real quantized conv3d.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.Conv3d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv3d ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv3d

        The convolution is dispatched through the inherited
        ``nn.Conv3d._conv_forward``, so ``padding_mode`` is honoured exactly as
        in the float module. For ``padding_mode="zeros"`` this is a plain
        ``F.conv3d(x, w, bias, stride, padding, dilation, groups)`` call.
        """
        weight_quant_dequant = self.get_weight()
        # Calling F.conv3d directly with self.padding would silently treat
        # reflect/replicate/circular padding modes as zero padding.
        return self._conv_forward(x, weight_quant_dequant, self.bias)

    def _get_name(self):
        return "QuantizedConv3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.Conv3d`` into this reference module."""
        return _ConvNd.from_float(cls, float_conv, weight_qparams)
241
242
class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
    """Common base for the reference (unpacked) quantized transposed convs.

    As with the regular reference convs, parameters are not packed. Packing
    is an optimization for PyTorch's quantized backends (fbgemm/qnnpack), and
    these modules target other backends such as Glow. This base overrides
    ``from_float`` because transposed convs take ``output_padding`` and order
    their constructor arguments differently.
    """

    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        """Create a ``cls`` instance configured like the float ``float_conv``.

        ``cls`` is passed explicitly so that concrete subclasses can forward
        their own class from their ``classmethod`` wrappers.

        Args:
            cls: concrete reference transposed-conv class to instantiate.
            float_conv: float transposed conv whose hyperparameters, weight
                and optional bias are copied.
            weight_qparams: weight quantization parameters passed on to ``cls``.

        Returns:
            The new reference module. Its ``weight``/``bias`` are fresh
            ``Parameter`` objects wrapping ``detach()``-ed float tensors.
        """
        src_weight = float_conv.weight
        src_bias = float_conv.bias
        # Positional order matches the transposed subclasses' __init__:
        # (in, out, kernel, stride, padding, output_padding, groups, bias,
        #  dilation, padding_mode)
        ctor_args = (
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,
            float_conv.stride,
            float_conv.padding,
            float_conv.output_padding,
            float_conv.groups,
            src_bias is not None,
            float_conv.dilation,
            float_conv.padding_mode,
        )
        converted = cls(
            *ctor_args,
            device=src_weight.device,
            dtype=src_weight.dtype,
            weight_qparams=weight_qparams,
        )
        converted.weight = torch.nn.Parameter(src_weight.detach())
        if src_bias is not None:
            converted.bias = torch.nn.Parameter(src_bias.detach())
        return converted
271
272
class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
    """Reference quantized counterpart of ``nn.ConvTranspose1d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward, followed
    by the ordinary float ``F.conv_transpose1d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose1d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose1d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d

        Args:
            x: input tensor.
            output_size: optional requested output size. As in
                ``nn.ConvTranspose1d``, it is used to derive the extra output
                padding.

        Raises:
            ValueError: if ``padding_mode`` is not ``"zeros"``.
        """
        # Mirrors the float nn.ConvTranspose1d.forward guard.
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        # NOTE: the argument order follows the float nn.ConvTranspose1d.forward
        # call: (input, output_size, stride, padding, kernel_size,
        # num_spatial_dims, dilation). The input tensor must be ``x``; passing
        # the ``input`` builtin fails whenever ``output_size`` is given.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        return F.conv_transpose1d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

    def _get_name(self):
        return "QuantizedConvTranspose1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.ConvTranspose1d`` into this reference module."""
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
351
352
class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
    """Reference quantized counterpart of ``nn.ConvTranspose2d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward, followed
    by the ordinary float ``F.conv_transpose2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose2d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d

        Args:
            x: input tensor.
            output_size: optional requested output size. As in
                ``nn.ConvTranspose2d``, it is used to derive the extra output
                padding.

        Raises:
            ValueError: if ``padding_mode`` is not ``"zeros"``.
        """
        # Mirrors the float nn.ConvTranspose2d.forward guard.
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose2d"
            )

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        # NOTE: the argument order follows the float nn.ConvTranspose2d.forward
        # call: (input, output_size, stride, padding, kernel_size,
        # num_spatial_dims, dilation). The input tensor must be ``x``; passing
        # the ``input`` builtin fails whenever ``output_size`` is given.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        return F.conv_transpose2d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

    def _get_name(self):
        return "QuantizedConvTranspose2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.ConvTranspose2d`` into this reference module."""
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
432
433
class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
    """Reference quantized counterpart of ``nn.ConvTranspose3d``.

    The weight is kept in floating point and passed through
    quantize/dequantize (``self.get_weight()``) on every forward, followed
    by the ordinary float ``F.conv_transpose3d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
        weight_qparams: Optional[Dict[str, Any]] = None,
    ):
        nn.ConvTranspose3d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        self._init_weight_qparams(weight_qparams, device)

    def forward(
        self, x: torch.Tensor, output_size: Optional[List[int]] = None
    ) -> torch.Tensor:
        r"""
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose3d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv3d

        Args:
            x: input tensor.
            output_size: optional requested output size. As in
                ``nn.ConvTranspose3d``, it is used to derive the extra output
                padding.

        Raises:
            ValueError: if ``padding_mode`` is not ``"zeros"``.
        """
        # Mirrors the float nn.ConvTranspose3d.forward guard.
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose3d"
            )

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        # NOTE: the argument order follows the float nn.ConvTranspose3d.forward
        # call: (input, output_size, stride, padding, kernel_size,
        # num_spatial_dims, dilation). The input tensor must be ``x``; passing
        # the ``input`` builtin fails whenever ``output_size`` is given.
        output_padding = self._output_padding(
            x,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        weight_quant_dequant = self.get_weight()
        return F.conv_transpose3d(
            x,
            weight_quant_dequant,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

    def _get_name(self):
        return "QuantizedConvTranspose3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        """Convert a float ``nn.ConvTranspose3d`` into this reference module."""
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
512