xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/common_quantized.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

r"""Common utility methods for checking quantized
tensors and modules.
"""
import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as they are flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')

def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1
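
# Editorial example (not part of the original file): a worked instance of the
# formula above, checked against standard convolution arithmetic. The helper
# name below is illustrative only.
def _example_conv_output_shape():
    # input_size=10, kernel_size=3, padding=1, stride=2, dilation=1:
    # floor((10 + 2*1 - 3 - 0) / 2) + 1 = 5, matching nn.Conv1d's output length.
    return _conv_output_shape(10, 3, padding=1, stride=2, dilation=1)  # -> 5.0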

# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx


def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x
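
# Editorial example (not part of the original file): round-tripping a float
# array through _quantize/_dequantize with hand-picked scale/zero_point; the
# helper name and parameter values are illustrative only.
def _example_quantize_roundtrip():
    x = np.linspace(-1.0, 1.0, num=8, dtype=np.float32)
    qx = _quantize(x, scale=0.05, zero_point=128)         # uint8 codes in [108, 148]
    x_hat = _dequantize(qx, scale=0.05, zero_point=128)   # float reconstruction
    return np.abs(x - x_hat).max()                        # rounding error, at most scale / 2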


def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type"""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx
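
# Editorial example (not part of the original file): requantizing int32
# accumulators, as produced by a quantized matmul, back to uint8. The
# multiplier would normally be input_scale * weight_scale / output_scale;
# the values here are illustrative only.
def _example_requantize():
    acc = np.array([-1000, 0, 5000, 40000], dtype=np.int32)
    return _requantize(acc, multiplier=0.002, zero_point=128)  # -> [126, 128, 138, 208]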

def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]
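
# Editorial example (not part of the original file): deriving (scale, zero_point)
# from a tensor's observed range and quantizing with them; the helper name is
# illustrative only.
def _example_dynamic_qparams():
    X = torch.randn(4, 4)
    scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
    # scale/zero_point are chosen so the observed min/max (and 0.0) are representable.
    return _quantize(X.numpy(), scale, zero_point)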

def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        min_val = X.min()
        max_val = X.max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])

    return scale, zero_point
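
# Editorial example (not part of the original file): per-channel parameters for
# a weight-like 2-D tensor, using a plain integer dtype for torch.iinfo; the
# helper name is illustrative only.
def _example_per_channel_qparams():
    W = torch.randn(3, 16)  # 3 output channels
    scales, zero_points = _calculate_dynamic_per_channel_qparams(W, torch.uint8)
    return scales.shape, zero_points.shape  # one (scale, zero_point) entry per channel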

def _snr(x, x_hat):
    """Calculates the signal to noise ratio and returns the signal and noise
    power, as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR(in dB): Either floats or a nested list of floats
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = []
        for idx in range(len(x)):
            res.append(_snr(x[idx], x_hat[idx]))
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db
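
# Editorial example (not part of the original file): measuring the error
# introduced by per-tensor quantization; the helper name and the scale /
# zero_point values are illustrative only.
def _example_snr():
    x = torch.randn(100)
    x_q = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
    signal, noise, snr_db = _snr(x, x_q)  # x_q is dequantized internally
    return snr_db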

@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous
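
# Editorial example (not part of the original file): the context manager swaps
# the backend only for the duration of the block and restores it afterwards,
# even on exceptions; the helper name is illustrative only.
def _example_override_engine():
    if 'fbgemm' in supported_qengines:
        with override_quantized_engine('fbgemm'):
            assert torch.backends.quantized.engine == 'fbgemm'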

@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()
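
# Editorial example (not part of the original file): pairing the QNNPACK engine
# with the mobile CPU allocator for the duration of a test body; the helper
# name is illustrative only.
def _example_qnnpack_allocator():
    if 'qnnpack' in supported_qengines:
        with override_quantized_engine('qnnpack'):
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack=True):
                pass  # quantized ops here run on QNNPACK with the mobile allocator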

# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn
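
# Editorial example (not part of the original file): the decorator re-runs the
# wrapped test once per engine in supported_qengines; the function name is
# illustrative only and is not collected as a real test.
@override_qengines
def _example_override_qengines_test():
    assert torch.backends.quantized.engine in supported_qengines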

def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'
def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'
def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'
def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'

# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list

# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)
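
# Editorial example (not part of the original file): fake-quantizing along the
# channel axis of a small activation tensor with the reference above; all
# shapes and parameter values are illustrative only.
def _example_fake_quant_per_channel():
    X = torch.randn(2, 3, 4)
    scale = torch.rand(3) + 0.1          # one scale per channel on axis=1
    zero_point = torch.randint(0, 10, (3,)).float()
    Y = _fake_quantize_per_channel_affine_reference(X, scale, zero_point, axis=1,
                                                    quant_min=0, quant_max=255)
    return Y.shape  # same shape as X, values snapped to each channel's grid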

# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)
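
# Editorial example (not part of the original file): the gradient reference acts
# as a straight-through estimator, passing dY only where the quantized value
# stays inside [quant_min, quant_max]; names and values are illustrative only.
def _example_fake_quant_per_channel_grad():
    X = torch.randn(2, 3, 4)
    dY = torch.ones_like(X)
    scale = torch.full((3,), 0.1)
    zero_point = torch.zeros(3)
    return _fake_quantize_per_channel_affine_grad_reference(
        dY, X, scale, zero_point, axis=1, quant_min=0, quant_max=255)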

def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.clone().detach()
    return X.to(device=torch.device(device), dtype=torch.float32)