# mypy: allow-untyped-defs
import abc
import collections
import itertools

import torch
from torch.nn.modules.module import _addindent


__all__ = [
    "WeightedQuantizedModule",
]


class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
    """Wrapper for quantized modules that can be lowered from reference modules."""

    @classmethod
    @abc.abstractmethod
    def from_reference(cls, ref_module, output_scale, output_zero_point):
        raise NotImplementedError

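# Example (illustrative sketch only): a concrete subclass implementing
# `from_reference`, loosely modeled on how the quantized Linear module lowers
# from its reference counterpart. `MyQuantizedLinear` and the attributes it
# touches are assumptions for illustration, not part of this module.
#
#     class MyQuantizedLinear(WeightedQuantizedModule):
#         @classmethod
#         def from_reference(cls, ref_module, output_scale, output_zero_point):
#             qlinear = cls(ref_module.in_features, ref_module.out_features)
#             qweight = ref_module.get_quantized_weight()
#             qlinear.set_weight_bias(qweight, ref_module.bias)
#             qlinear.scale = float(output_scale)
#             qlinear.zero_point = int(output_zero_point)
#             return qlinear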

def _get_weight_observer(observer):
    # FakeQuantize wraps the real observer in `activation_post_process`
    if hasattr(observer, "activation_post_process"):
        observer = observer.activation_post_process
    # otherwise this is already a UniformQuantizationObserverBase observer
    return observer

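# Example (illustrative sketch): unwrapping a QAT-style FakeQuantize to get at
# the underlying observer; a plain observer passes through unchanged.
#
#     from torch.ao.quantization import FakeQuantize
#     fq = FakeQuantize()                       # wraps an observer internally
#     obs = _get_weight_observer(fq)            # -> fq.activation_post_process
#     assert obs is fq.activation_post_process
#     assert _get_weight_observer(obs) is obs   # already a plain observer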

def _needs_weight_clamping(observer, dtype):
    observer = _get_weight_observer(observer)
    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
        info = torch.iinfo(dtype)
        return observer.quant_min > info.min or observer.quant_max < info.max
    return False

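# Example (illustrative sketch; the observer kwargs below are assumptions for
# demonstration, not defaults required by this helper):
#
#     from torch.ao.quantization.observer import MinMaxObserver
#     narrow = MinMaxObserver(
#         dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
#         quant_min=-127, quant_max=127,
#     )
#     # torch.iinfo(torch.qint8) spans [-128, 127], so quant_min=-127 is narrower:
#     _needs_weight_clamping(narrow, torch.qint8)  # True
#     full = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
#     _needs_weight_clamping(full, torch.qint8)    # False (full qint8 range)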

def _clamp_weights(qweight, observer, scale, zp):
    if not _needs_weight_clamping(observer, qweight.dtype):
        return qweight

    observer = _get_weight_observer(observer)
    min_, max_ = observer.quant_min, observer.quant_max

    # Clamp the int representation manually because torch.ops.quantized.clamp()
    # does not support the per_channel qscheme yet.
    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
    qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
    qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)

    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch._make_per_tensor_quantized_tensor(
            qw_int, scale.item(), zp.item()
        )
    elif observer.qscheme in [
        torch.per_channel_symmetric,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        qweight = torch._make_per_channel_quantized_tensor(
            qw_int, scale, zp, axis=observer.ch_axis
        )
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight

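# Example (illustrative sketch of the per-tensor case; the narrowed
# [-127, 127] range and the random weight are chosen only for demonstration):
#
#     from torch.ao.quantization.observer import MinMaxObserver
#     obs = MinMaxObserver(
#         dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
#         quant_min=-127, quant_max=127,
#     )
#     w = torch.randn(4, 4)
#     obs(w)
#     scale, zp = obs.calculate_qparams()
#     qw = torch.quantize_per_tensor(w, float(scale), int(zp), torch.qint8)
#     qw = _clamp_weights(qw, obs, scale, zp)
#     # the integer values of `qw` now lie within [obs.quant_min, obs.quant_max]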

def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double),
            wt_zp.to(torch.int64),
            wt_axis,
            torch.qint8,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.float),
            wt_zp.to(torch.float),
            observer.ch_axis,
            observer.dtype,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight

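# Example (illustrative sketch; a per-channel symmetric weight observer is one
# common choice, not a requirement of this helper):
#
#     from torch.ao.quantization.observer import PerChannelMinMaxObserver
#     observer = PerChannelMinMaxObserver(
#         dtype=torch.qint8, qscheme=torch.per_channel_symmetric
#     )
#     float_wt = torch.randn(8, 4)   # e.g. a Linear weight of shape (out, in)
#     observer(float_wt)             # record per-channel min/max
#     qweight = _quantize_weight(float_wt, observer)
#     qweight.dtype                  # torch.qint8
#     qweight.q_per_channel_axis()   # 0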

def _ntuple_from_first(n):
    """Converts the argument to a tuple of size n
    with the first element repeated."""

    def parse(x):
        # Drill into nested sequences via the first element until reaching
        # either a non-sequence or a sequence that already has length n,
        # then repeat whatever remains n times.
        while isinstance(x, collections.abc.Sequence):
            if len(x) == n:
                break
            x = x[0]
        return tuple(itertools.repeat(x, n))

    return parse


def _hide_packed_params_repr(self, params):
    # We don't want to show `PackedParams` children, hence the custom
    # `__repr__`. This is the same as nn.Module.__repr__, except for the
    # check against the `params` class.
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split("\n")
    child_lines = []
    for key, module in self._modules.items():
        if isinstance(module, params):
            continue
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append("(" + key + "): " + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + "("
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += "\n  " + "\n  ".join(lines) + "\n"

    main_str += ")"
    return main_str

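# Example (illustrative sketch of how a quantized module can use this helper
# in its own `__repr__`; `MyQuantizedLinear` and `LinearPackedParams` stand in
# for the module and the packed-params child class to hide):
#
#     class MyQuantizedLinear(torch.nn.Module):
#         ...
#         def __repr__(self):
#             return _hide_packed_params_repr(self, LinearPackedParams)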

_pair_from_first = _ntuple_from_first(2)
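
# Example (illustrative sketch of `_pair_from_first` on typical inputs):
#
#     _pair_from_first(3)      # (3, 3)
#     _pair_from_first([3])    # (3, 3) -- the first element is unwrapped
#     _pair_from_first([[5]])  # (5, 5) -- nesting is unwrapped layer by layer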