# mypy: ignore-errors

r"""Common utility methods for checking quantized tensors and modules."""
import functools
import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on Windows and macOS as they are flaky. Issue #29326
# QNNPACK is not supported on PPC.
# QNNPACK throws an ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')


def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1


# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx


def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x


def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type."""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx


def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculates the dynamic quantization parameters (scale, zero_point)
    according to the min and max elements of the tensor."""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]
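
# Illustrative sketch (added for exposition; not part of the original utilities).
# It shows how the reference helpers above compose: dynamic qparams are computed
# for a float tensor, which is then quantized and dequantized with them. The
# name `_example_qparams_roundtrip` is hypothetical and exists only for this demo.
def _example_qparams_roundtrip():
    X = torch.randn(4, 8)
    scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
    qX = _quantize(X.numpy(), scale, zero_point)  # uint8 values in [0, 255]
    X_hat = _dequantize(qX, scale, zero_point)    # float approximation of X
    # With round-to-nearest and a range that covers X, the per-element
    # reconstruction error stays within half a quantization step.
    assert np.abs(X.numpy() - X_hat).max() <= scale / 2 + 1e-6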

def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculates the per-channel dynamic quantization parameters (scale,
    zero_point) according to the min and max elements of each channel."""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        # Reduce over channel i only, so each channel gets its own qparams.
        min_val = X[i].min()
        max_val = X[i].max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])

    return scale, zero_point


def _snr(x, x_hat):
    """Calculates the signal-to-noise ratio and returns the signal and noise
    magnitudes (L2 norms), as well as the SNR in dB.
    If the input is a list/tuple, this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR (in dB): Either floats or a nested list of floats.
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = []
        for idx in range(len(x)):
            res.append(_snr(x[idx], x_hat[idx]))
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    signal = x.norm()
    noise = (x - x_hat).norm()
    if noise == 0:
        # A perfect reconstruction has zero noise and hence infinite SNR.
        return signal, 0.0, float('inf')
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db


@contextmanager
def override_quantized_engine(qengine):
    """Temporarily sets the active quantized engine, restoring the previous
    one on exit."""
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous


@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    """Temporarily switches to the default mobile CPU allocator when running
    with the QNNPACK engine."""
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()
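
# Usage sketch (added for exposition; not part of the original utilities).
# `_example_engine_override` is a hypothetical name introduced only to show how
# the context manager above is typically used: the engine is switched inside
# the block and restored afterwards, even if the body raises.
def _example_engine_override():
    previous = torch.backends.quantized.engine
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            assert torch.backends.quantized.engine == qengine
    assert torch.backends.quantized.engine == previous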

# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    @functools.wraps(qfunction)
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn


def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'


def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'


def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'


def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'


# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list


# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)


# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    # Straight-through estimator: only positions whose integer value falls
    # inside [quant_min, quant_max] pass the gradient through.
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)


def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.clone().detach()
    return X.to(device=torch.device(device), dtype=torch.float32)
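
# Illustrative sketch (added for exposition; not part of the original utilities).
# `_example_fake_quant_reference` is a hypothetical name introduced only to show
# the per-channel fake-quant reference in action; fake-quant is idempotent, so
# applying it twice leaves the result unchanged.
def _example_fake_quant_reference():
    X = torch.randn(2, 3, 4)
    scale = torch.tensor([0.1, 0.2, 0.3])       # one scale per channel on axis 1
    zero_point = torch.tensor([0.0, 5.0, 10.0])  # float zero points, as the kernel allows
    Y = _fake_quantize_per_channel_affine_reference(X, scale, zero_point, axis=1, quant_min=0, quant_max=255)
    Y2 = _fake_quantize_per_channel_affine_reference(Y, scale, zero_point, axis=1, quant_min=0, quant_max=255)
    assert Y.shape == X.shape
    assert torch.allclose(Y, Y2)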