# xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/modules/normalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4
5__all__ = [
6    "LayerNorm",
7    "GroupNorm",
8    "InstanceNorm1d",
9    "InstanceNorm2d",
10    "InstanceNorm3d",
11]
12
13
class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        normalized_shape,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            **factory_kwargs,
        )
        # Replace the parameters created by the float base class with the
        # caller-provided (already trained) ones.
        self.weight = weight
        self.bias = bias
        # Output qparams are buffers so they follow .to() / state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer(
            "zero_point", torch.tensor(zero_point, **factory_kwargs)
        )

    def forward(self, input):
        # Fused quantized kernel; the output is re-quantized with the
        # stored scale / zero_point.
        return torch.ops.quantized.layer_norm(
            input,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
            output_scale=self.scale,
            output_zero_point=self.zero_point,
        )

    def _get_name(self):
        return "QuantizedLayerNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized module from an observed float module ``mod``."""
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized module from a reference module plus output qparams."""
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )
86
87
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    __constants__ = ["num_groups", "num_channels", "eps", "affine"]

    def __init__(
        self,
        num_groups,
        num_channels,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        # Overwrite the float parameters with the caller-provided ones.
        self.weight = weight
        self.bias = bias
        # Output qparams live in buffers so they follow .to() / state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input,
            self.num_groups,
            self.weight,
            self.bias,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedGroupNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized module from an observed float module ``mod``."""
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized module from a reference module plus output qparams.

        Added for consistency: every other class in this file
        (``LayerNorm``, ``InstanceNorm1d/2d/3d``) already supports the
        reference-module conversion path.
        """
        return cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )
146
147
class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Overwrite the float parameters with the caller-provided ones.
        self.weight = weight
        self.bias = bias
        # Output qparams live in buffers so they follow .to() / state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized module from an observed float module ``mod``.

        ``affine`` is passed by keyword: the 7th positional slot of
        ``__init__`` is ``momentum``, so passing ``mod.affine`` positionally
        (as the previous code did) silently set ``momentum=mod.affine`` and
        left ``affine`` at its default ``False``. ``momentum`` and
        ``track_running_stats`` keep their defaults; the quantized
        instance_norm op does not use running statistics.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized module from a reference module plus output qparams.

        ``affine`` is keyword-only here for the same reason as in
        ``from_float``: passed positionally it would land in the
        ``momentum`` slot.
        """
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
213
214
class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Overwrite the float parameters with the caller-provided ones.
        self.weight = weight
        self.bias = bias
        # Output qparams live in buffers so they follow .to() / state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized module from an observed float module ``mod``.

        ``affine`` is passed by keyword: the 7th positional slot of
        ``__init__`` is ``momentum``, so passing ``mod.affine`` positionally
        (as the previous code did) silently set ``momentum=mod.affine`` and
        left ``affine`` at its default ``False``. ``momentum`` and
        ``track_running_stats`` keep their defaults; the quantized
        instance_norm op does not use running statistics.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized module from a reference module plus output qparams.

        ``affine`` is keyword-only here for the same reason as in
        ``from_float``: passed positionally it would land in the
        ``momentum`` slot.
        """
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
280
281
class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Overwrite the float parameters with the caller-provided ones.
        self.weight = weight
        self.bias = bias
        # Output qparams live in buffers so they follow .to() / state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized module from an observed float module ``mod``.

        ``affine`` is passed by keyword: the 7th positional slot of
        ``__init__`` is ``momentum``, so passing ``mod.affine`` positionally
        (as the previous code did) silently set ``momentum=mod.affine`` and
        left ``affine`` at its default ``False``. ``momentum`` and
        ``track_running_stats`` keep their defaults; the quantized
        instance_norm op does not use running statistics.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized module from a reference module plus output qparams.

        ``affine`` is keyword-only here for the same reason as in
        ``from_float``: passed positionally it would land in the
        ``momentum`` slot.
        """
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            affine=mod.affine,
        )
347