# mypy: allow-untyped-defs
from typing import List

import torch
from torch import Tensor
from torch._ops import ops


__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]


class FloatFunctional(torch.nn.Module):
    r"""State collector class for float operations.

    The instance of this class can be used instead of the ``torch.`` prefix for
    some operations. See example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> f_add = FloatFunctional()
        >>> a = torch.tensor(3.0)
        >>> b = torch.tensor(4.0)
        >>> f_add.add(a, b)  # Equivalent to ``torch.add(a, b)``
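
    A minimal sketch of the intended usage inside a model (``M`` below is a
    hypothetical example module), so that quantization workflows can later
    attach an observer to the op::

        >>> class M(torch.nn.Module):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.f_add = FloatFunctional()
        ...     def forward(self, x, y):
        ...         # using ``self.f_add.add`` instead of ``+`` makes the add
        ...         # visible to quantization passes
        ...         return self.f_add.add(x, y)
        >>> m = M()
        >>> out = m(a, b)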

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
    """

    def __init__(self) -> None:
        super().__init__()
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        raise RuntimeError(
            "FloatFunctional is not intended to be called via 'forward'. "
            + "Please use one of the underlying operations (e.g. 'add') instead"
        )

    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.add(Tensor, float)``"""

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.add(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.mul(x, y)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.mul(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.cat``"""

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r = torch.cat(x, dim=dim)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``relu(torch.add(x,y))``"""

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.matmul(x, y)
        r = self.activation_post_process(r)
        return r


class FXFloatFunctional(torch.nn.Module):
    r"""Module to replace FloatFunctional before FX graph mode quantization,
    since activation_post_process will be inserted into the top level module directly.

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
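
    The ops themselves simply call the corresponding float operations; a small
    usage sketch::

        >>> f = FXFloatFunctional()
        >>> out = f.add_relu(torch.randn(2), torch.randn(2))  # no observer is attached here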
    """

    def forward(self, x):
        raise RuntimeError(
            "FXFloatFunctional is not intended to be called via 'forward'. "
            + "Please use one of the underlying operations (e.g. 'add') instead"
        )

    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        return r

    r"""Operation equivalent to ``torch.add(Tensor, float)``"""

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.add(x, y)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.mul(x, y)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.mul(x, y)
        return r

    r"""Operation equivalent to ``torch.cat``"""

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r = torch.cat(x, dim=dim)
        return r

    r"""Operation equivalent to ``relu(torch.add(x,y))``"""

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        return r

    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.matmul(x, y)
        return r


class QFunctional(torch.nn.Module):
    r"""Wrapper class for quantized operations.

    The instance of this class can be used instead of the
    ``torch.ops.quantized`` prefix. See example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> q_add = QFunctional()
        >>> # xdoctest: +SKIP
        >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
        >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
        >>> q_add.add(a, b)  # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
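
    A sketch of the typical conversion path (this assumes a calibrated observer
    has been attached as ``activation_post_process``; ``MinMaxObserver`` is just
    one possible choice)::

        >>> float_fn = FloatFunctional()
        >>> float_fn.activation_post_process = torch.ao.quantization.MinMaxObserver()
        >>> _ = float_fn.add(torch.randn(4), torch.randn(4))  # calibration pass
        >>> q_fn = QFunctional.from_float(float_fn)  # reads scale/zero_point from the observer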

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
        - matmul
    """

    def __init__(self) -> None:
        super().__init__()
        self.scale = 1.0
        self.zero_point = 0
        self.activation_post_process = torch.nn.Identity()

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = torch.tensor(self.scale)
        destination[prefix + "zero_point"] = torch.tensor(self.zero_point)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.scale = float(state_dict.pop(prefix + "scale"))
        self.zero_point = int(state_dict.pop(prefix + "zero_point"))
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _get_name(self):
        return "QFunctional"

    def extra_repr(self):
        return f"scale={self.scale}, zero_point={self.zero_point}"

    def forward(self, x):
        raise RuntimeError(
            "QFunctional is not intended to be called via 'forward'. "
            + "Please use one of the underlying operations (e.g. 'add') instead"
        )

    r"""Operation equivalent to ``torch.ops.quantized.add``"""

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r = ops.quantized.add_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r = ops.quantized.mul_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.ops.quantized.cat``"""

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""

    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert (
            type(mod) == FloatFunctional
        ), "QFunctional.from_float expects an instance of FloatFunctional"
        scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
        new_mod = QFunctional()
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod